diff --git a/tobac/tests/test_decorators.py b/tobac/tests/test_decorators.py index 01a3a0ad..039babde 100644 --- a/tobac/tests/test_decorators.py +++ b/tobac/tests/test_decorators.py @@ -1,6 +1,7 @@ """ Tests for tobac.utils.decorators """ + import numpy as np import pandas as pd import xarray as xr diff --git a/tobac/tests/test_utils_bulk_statistics.py b/tobac/tests/test_utils_bulk_statistics.py index 4967a779..1db036b0 100644 --- a/tobac/tests/test_utils_bulk_statistics.py +++ b/tobac/tests/test_utils_bulk_statistics.py @@ -8,7 +8,10 @@ import tobac.testing as tb_test -def test_bulk_statistics(): +@pytest.mark.parametrize( + "id_column, index", [("feature", [1]), ("feature_id", [1]), ("cell", [1])] +) +def test_bulk_statistics(id_column, index): """ Test to assure that bulk statistics for identified features are computed as expected. @@ -46,10 +49,10 @@ def test_bulk_statistics(): ) #### checks - + out_df = out_df.rename(columns={"feature": id_column}) # assure that bulk statistics in postprocessing give same result out_segmentation = tb_utils.get_statistics_from_mask( - out_df, out_seg_mask, test_data_iris, statistic=stats + out_df, out_seg_mask, test_data_iris, statistic=stats, id_column=id_column ) assert out_segmentation.equals(out_df) @@ -86,11 +89,12 @@ def test_bulk_statistics(): ) ##### checks ##### - + out_df = out_df.rename(columns={"feature": id_column}) # assure that bulk statistics in postprocessing give same result out_segmentation = tb_utils.get_statistics_from_mask( - out_df, out_seg_mask, test_data_iris, statistic=stats + out_df, out_seg_mask, test_data_iris, statistic=stats, id_column=id_column ) + assert out_segmentation.equals(out_df) # assure that column names in new dataframe correspond to keys in statistics dictionary diff --git a/tobac/utils/bulk_statistics.py b/tobac/utils/bulk_statistics.py index cb9ce0cf..2dbcf6f2 100644 --- a/tobac/utils/bulk_statistics.py +++ b/tobac/utils/bulk_statistics.py @@ -122,7 +122,7 @@ def get_statistics( # mask must contain positive values to calculate statistics if np.any(labels > 0): if index is None: - index = features.feature.to_numpy() + index = features[id_column].to_numpy().astype(int) else: # get the statistics only for specified feature objects if np.max(index) > np.max(labels): @@ -266,10 +266,16 @@ def get_statistics_from_mask( Updated feature dataframe with bulk statistics for each feature saved in a new column """ # warning when feature labels are not unique in dataframe - if not features.feature.is_unique: - raise logging.warning( + if not features[id_column].is_unique: + logging.warning( "Feature labels are not unique which may cause unexpected results for the computation of bulk statistics." ) + # extra warning when feature labels are not unique in timestep + uniques = features.groupby("time")[id_column].value_counts().values + if not uniques[uniques > 1].size == 0: + logging.warning( + "Note that non-unique feature labels occur also in the same timestep. This likely causes unexpected results for the computation of bulk statistics." + ) if collapse_dim is not None: if isinstance(collapse_dim, str): @@ -299,7 +305,7 @@ def get_statistics_from_mask( # make sure that the labels in the segmentation mask exist in feature dataframe if ( - np.intersect1d(np.unique(segmentation_mask_t), features_t.feature).size + np.intersect1d(np.unique(segmentation_mask_t), features_t[id_column]).size > np.unique(segmentation_mask_t).size ): raise ValueError( diff --git a/tobac/utils/decorators.py b/tobac/utils/decorators.py index 8bc6657f..90e600b5 100644 --- a/tobac/utils/decorators.py +++ b/tobac/utils/decorators.py @@ -72,11 +72,13 @@ def _conv_kwargs_irispandas_to_xarray(conv_kwargs: dict): """ return { - key: convert_cube_to_dataarray(arg) - if isinstance(arg, iris.cube.Cube) - else arg.to_xarray() - if isinstance(arg, pd.DataFrame) - else arg + key: ( + convert_cube_to_dataarray(arg) + if isinstance(arg, iris.cube.Cube) + else arg.to_xarray() + if isinstance(arg, pd.DataFrame) + else arg + ) for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values()) } @@ -118,11 +120,13 @@ def _conv_kwargs_xarray_to_irispandas(conv_kwargs: dict): iris cubes """ return { - key: xr.DataArray.to_iris(arg) - if isinstance(arg, xr.DataArray) - else arg.to_dataframe() - if isinstance(arg, xr.Dataset) - else arg + key: ( + xr.DataArray.to_iris(arg) + if isinstance(arg, xr.DataArray) + else arg.to_dataframe() + if isinstance(arg, xr.Dataset) + else arg + ) for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values()) } @@ -166,9 +170,11 @@ def wrapper(*args, **kwargs): # print("converting iris to xarray and back") args = tuple( [ - convert_cube_to_dataarray(arg) - if type(arg) == iris.cube.Cube - else arg + ( + convert_cube_to_dataarray(arg) + if type(arg) == iris.cube.Cube + else arg + ) for arg in args ] ) @@ -179,9 +185,11 @@ def wrapper(*args, **kwargs): if type(output) == tuple: output = tuple( [ - xarray.DataArray.to_iris(output_item) - if type(output_item) == xarray.DataArray - else output_item + ( + xarray.DataArray.to_iris(output_item) + if type(output_item) == xarray.DataArray + else output_item + ) for output_item in output ] ) @@ -241,9 +249,11 @@ def wrapper(*args, **kwargs): # print("converting xarray to iris and back") args = tuple( [ - xarray.DataArray.to_iris(arg) - if type(arg) == xarray.DataArray - else arg + ( + xarray.DataArray.to_iris(arg) + if type(arg) == xarray.DataArray + else arg + ) for arg in args ] ) @@ -257,9 +267,11 @@ def wrapper(*args, **kwargs): if type(output) == tuple: output = tuple( [ - xarray.DataArray.from_iris(output_item) - if type(output_item) == iris.cube.Cube - else output_item + ( + xarray.DataArray.from_iris(output_item) + if type(output_item) == iris.cube.Cube + else output_item + ) for output_item in output ] ) @@ -325,11 +337,13 @@ def wrapper(*args, **kwargs): # print("converting iris to xarray and back") args = tuple( [ - convert_cube_to_dataarray(arg) - if type(arg) == iris.cube.Cube - else arg.to_xarray() - if type(arg) == pd.DataFrame - else arg + ( + convert_cube_to_dataarray(arg) + if type(arg) == iris.cube.Cube + else arg.to_xarray() + if type(arg) == pd.DataFrame + else arg + ) for arg in args ] ) @@ -339,11 +353,15 @@ def wrapper(*args, **kwargs): if type(output) == tuple: output = tuple( [ - xarray.DataArray.to_iris(output_item) - if type(output_item) == xarray.DataArray - else output_item.to_dataframe() - if type(output_item) == xarray.Dataset - else output_item + ( + xarray.DataArray.to_iris(output_item) + if type(output_item) == xarray.DataArray + else ( + output_item.to_dataframe() + if type(output_item) == xarray.Dataset + else output_item + ) + ) for output_item in output ] ) @@ -415,11 +433,15 @@ def wrapper(*args, **kwargs): # print("converting xarray to iris and back") args = tuple( [ - xarray.DataArray.to_iris(arg) - if type(arg) == xarray.DataArray - else arg.to_dataframe() - if type(arg) == xarray.Dataset - else arg + ( + xarray.DataArray.to_iris(arg) + if type(arg) == xarray.DataArray + else ( + arg.to_dataframe() + if type(arg) == xarray.Dataset + else arg + ) + ) for arg in args ] ) @@ -433,11 +455,15 @@ def wrapper(*args, **kwargs): if type(output) == tuple: output = tuple( [ - xarray.DataArray.from_iris(output_item) - if type(output_item) == iris.cube.Cube - else output_item.to_xarray() - if type(output_item) == pd.DataFrame - else output_item + ( + xarray.DataArray.from_iris(output_item) + if type(output_item) == iris.cube.Cube + else ( + output_item.to_xarray() + if type(output_item) == pd.DataFrame + else output_item + ) + ) for output_item in output ] ) diff --git a/tobac/utils/periodic_boundaries.py b/tobac/utils/periodic_boundaries.py index 71ecb38e..e230aca5 100644 --- a/tobac/utils/periodic_boundaries.py +++ b/tobac/utils/periodic_boundaries.py @@ -1,5 +1,6 @@ """Utilities for handling indexing and distance calculation with periodic boundaries """ + from __future__ import annotations import functools