diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py b/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py index b1930c8..fc32d9e 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py @@ -69,6 +69,35 @@ _LENS2_STD_STATS_PATH = "/lzepedanunez/data/lens2/stats/lens2_std_stats_all_variables_240x121_lonlat_1961-2000.zarr" +def maybe_expand_dims( + array: np.ndarray, + allowed_dims: tuple[int, ...], + trigger_expand_dims: int, + axis: int = -1, +) -> np.ndarray: + """Expands the dimensions of a numpy array if necessary. + + Args: + array: The numpy array to be possibly expanded. + allowed_dims: The dimensions that the array can have, raise an error + otherwise. + trigger_expand_dims: The dimension that triggers the expansion. + axis: The axis in which the extra dimension is added. + + Returns: + The array possibly expanded if its dimension is trigger_expand_dims. + """ + ndim = array.ndim + if ndim not in allowed_dims: + raise ValueError( + f"The array has {ndim} dimensions, but it should have one of the" + f" dimensions {allowed_dims}" + ) + if ndim == trigger_expand_dims: + array = np.expand_dims(array, axis=axis) + return array + + class UnpairedDataLoader: """Unpaired dataloader for loading samples from two distributions."""