Merge pull request #437 from JuliaKukulies/bulk_stats_bug_fix
Bulk stats bug fix
JuliaKukulies authored Aug 15, 2024
2 parents 67f86cf + 2ede64b commit dadad41
Showing 5 changed files with 89 additions and 51 deletions.
1 change: 1 addition & 0 deletions tobac/tests/test_decorators.py
@@ -1,6 +1,7 @@
"""
Tests for tobac.utils.decorators
"""

import numpy as np
import pandas as pd
import xarray as xr
14 changes: 9 additions & 5 deletions tobac/tests/test_utils_bulk_statistics.py
@@ -8,7 +8,10 @@
 import tobac.testing as tb_test
 
 
-def test_bulk_statistics():
+@pytest.mark.parametrize(
+    "id_column, index", [("feature", [1]), ("feature_id", [1]), ("cell", [1])]
+)
+def test_bulk_statistics(id_column, index):
     """
     Test to assure that bulk statistics for identified features are computed as expected.
@@ -46,10 +49,10 @@ def test_bulk_statistics():
     )
 
     #### checks
-
+    out_df = out_df.rename(columns={"feature": id_column})
     # assure that bulk statistics in postprocessing give same result
     out_segmentation = tb_utils.get_statistics_from_mask(
-        out_df, out_seg_mask, test_data_iris, statistic=stats
+        out_df, out_seg_mask, test_data_iris, statistic=stats, id_column=id_column
     )
     assert out_segmentation.equals(out_df)

@@ -86,11 +89,12 @@ def test_bulk_statistics():
     )
 
     ##### checks #####
 
+    out_df = out_df.rename(columns={"feature": id_column})
     # assure that bulk statistics in postprocessing give same result
     out_segmentation = tb_utils.get_statistics_from_mask(
-        out_df, out_seg_mask, test_data_iris, statistic=stats
+        out_df, out_seg_mask, test_data_iris, statistic=stats, id_column=id_column
     )
 
     assert out_segmentation.equals(out_df)
 
     # assure that column names in new dataframe correspond to keys in statistics dictionary
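The parametrization added above runs the whole test once per candidate id column, renaming the default "feature" column before calling the statistics utility. A minimal, self-contained sketch of that rename-then-check pattern (the toy dataframe here is invented for illustration; the real test goes on to call get_statistics_from_mask on tobac.testing data):

```python
import pandas as pd
import pytest


@pytest.mark.parametrize("id_column", ["feature", "feature_id", "cell"])
def test_rename_id_column(id_column):
    # toy feature dataframe standing in for tobac's test data
    df = pd.DataFrame({"feature": [1, 2, 3], "area": [10.0, 20.0, 30.0]})
    # rename the default "feature" column to the column name under test
    df = df.rename(columns={"feature": id_column})
    assert id_column in df.columns
    assert df[id_column].is_unique
```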
14 changes: 10 additions & 4 deletions tobac/utils/bulk_statistics.py
@@ -122,7 +122,7 @@ def get_statistics(
     # mask must contain positive values to calculate statistics
     if np.any(labels > 0):
         if index is None:
-            index = features.feature.to_numpy()
+            index = features[id_column].to_numpy().astype(int)
         else:
             # get the statistics only for specified feature objects
             if np.max(index) > np.max(labels):
@@ -266,10 +266,16 @@ def get_statistics_from_mask(
         Updated feature dataframe with bulk statistics for each feature saved in a new column
     """
     # warning when feature labels are not unique in dataframe
-    if not features.feature.is_unique:
-        raise logging.warning(
+    if not features[id_column].is_unique:
+        logging.warning(
             "Feature labels are not unique which may cause unexpected results for the computation of bulk statistics."
         )
+    # extra warning when feature labels are not unique in timestep
+    uniques = features.groupby("time")[id_column].value_counts().values
+    if not uniques[uniques > 1].size == 0:
+        logging.warning(
+            "Note that non-unique feature labels occur also in the same timestep. This likely causes unexpected results for the computation of bulk statistics."
+        )
 
     if collapse_dim is not None:
         if isinstance(collapse_dim, str):
@@ -299,7 +305,7 @@

             # make sure that the labels in the segmentation mask exist in feature dataframe
             if (
-                np.intersect1d(np.unique(segmentation_mask_t), features_t.feature).size
+                np.intersect1d(np.unique(segmentation_mask_t), features_t[id_column]).size
                 > np.unique(segmentation_mask_t).size
             ):
                 raise ValueError(
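Taken together, these changes make get_statistics_from_mask honor a user-supplied id_column (for example "cell" for track-level masks) instead of assuming a "feature" column, replace the previously broken `raise logging.warning(...)` with a plain warning, and add a second warning when the same label repeats within a single timestep. A hedged usage sketch mirroring the call exercised in the tests above; `features`, `seg_mask`, and `field` are hypothetical placeholders, not names from this PR:

```python
import numpy as np
import tobac.utils as tb_utils


def cell_statistics(features, seg_mask, field):
    """Attach per-cell bulk statistics, keyed on a "cell" column instead of
    the default "feature" column. `features` is a feature dataframe;
    `seg_mask` and `field` are the segmentation mask and underlying data
    (hypothetical placeholders mirroring the test call above)."""
    # each key of the statistic dict becomes a new column in the result
    stats = {"segment_max": np.max, "segment_mean": np.mean}
    features = features.rename(columns={"feature": "cell"})
    return tb_utils.get_statistics_from_mask(
        features, seg_mask, field, statistic=stats, id_column="cell"
    )
```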
110 changes: 68 additions & 42 deletions tobac/utils/decorators.py
@@ -72,11 +72,13 @@ def _conv_kwargs_irispandas_to_xarray(conv_kwargs: dict):
"""
return {
key: convert_cube_to_dataarray(arg)
if isinstance(arg, iris.cube.Cube)
else arg.to_xarray()
if isinstance(arg, pd.DataFrame)
else arg
key: (
convert_cube_to_dataarray(arg)
if isinstance(arg, iris.cube.Cube)
else arg.to_xarray()
if isinstance(arg, pd.DataFrame)
else arg
)
for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values())
}

@@ -118,11 +120,13 @@ def _conv_kwargs_xarray_to_irispandas(conv_kwargs: dict):
     iris cubes
     """
     return {
-        key: xr.DataArray.to_iris(arg)
-        if isinstance(arg, xr.DataArray)
-        else arg.to_dataframe()
-        if isinstance(arg, xr.Dataset)
-        else arg
+        key: (
+            xr.DataArray.to_iris(arg)
+            if isinstance(arg, xr.DataArray)
+            else arg.to_dataframe()
+            if isinstance(arg, xr.Dataset)
+            else arg
+        )
         for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values())
     }

@@ -166,9 +170,11 @@ def wrapper(*args, **kwargs):
# print("converting iris to xarray and back")
args = tuple(
[
convert_cube_to_dataarray(arg)
if type(arg) == iris.cube.Cube
else arg
(
convert_cube_to_dataarray(arg)
if type(arg) == iris.cube.Cube
else arg
)
for arg in args
]
)
@@ -179,9 +185,11 @@ def wrapper(*args, **kwargs):
             if type(output) == tuple:
                 output = tuple(
                     [
-                        xarray.DataArray.to_iris(output_item)
-                        if type(output_item) == xarray.DataArray
-                        else output_item
+                        (
+                            xarray.DataArray.to_iris(output_item)
+                            if type(output_item) == xarray.DataArray
+                            else output_item
+                        )
                         for output_item in output
                     ]
                 )
@@ -241,9 +249,11 @@ def wrapper(*args, **kwargs):
# print("converting xarray to iris and back")
args = tuple(
[
xarray.DataArray.to_iris(arg)
if type(arg) == xarray.DataArray
else arg
(
xarray.DataArray.to_iris(arg)
if type(arg) == xarray.DataArray
else arg
)
for arg in args
]
)
@@ -257,9 +267,11 @@ def wrapper(*args, **kwargs):
             if type(output) == tuple:
                 output = tuple(
                     [
-                        xarray.DataArray.from_iris(output_item)
-                        if type(output_item) == iris.cube.Cube
-                        else output_item
+                        (
+                            xarray.DataArray.from_iris(output_item)
+                            if type(output_item) == iris.cube.Cube
+                            else output_item
+                        )
                         for output_item in output
                     ]
                 )
@@ -325,11 +337,13 @@ def wrapper(*args, **kwargs):
# print("converting iris to xarray and back")
args = tuple(
[
convert_cube_to_dataarray(arg)
if type(arg) == iris.cube.Cube
else arg.to_xarray()
if type(arg) == pd.DataFrame
else arg
(
convert_cube_to_dataarray(arg)
if type(arg) == iris.cube.Cube
else arg.to_xarray()
if type(arg) == pd.DataFrame
else arg
)
for arg in args
]
)
@@ -339,11 +353,15 @@ def wrapper(*args, **kwargs):
             if type(output) == tuple:
                 output = tuple(
                     [
-                        xarray.DataArray.to_iris(output_item)
-                        if type(output_item) == xarray.DataArray
-                        else output_item.to_dataframe()
-                        if type(output_item) == xarray.Dataset
-                        else output_item
+                        (
+                            xarray.DataArray.to_iris(output_item)
+                            if type(output_item) == xarray.DataArray
+                            else (
+                                output_item.to_dataframe()
+                                if type(output_item) == xarray.Dataset
+                                else output_item
+                            )
+                        )
                         for output_item in output
                     ]
                 )
@@ -415,11 +433,15 @@ def wrapper(*args, **kwargs):
# print("converting xarray to iris and back")
args = tuple(
[
xarray.DataArray.to_iris(arg)
if type(arg) == xarray.DataArray
else arg.to_dataframe()
if type(arg) == xarray.Dataset
else arg
(
xarray.DataArray.to_iris(arg)
if type(arg) == xarray.DataArray
else (
arg.to_dataframe()
if type(arg) == xarray.Dataset
else arg
)
)
for arg in args
]
)
@@ -433,11 +455,15 @@ def wrapper(*args, **kwargs):
             if type(output) == tuple:
                 output = tuple(
                     [
-                        xarray.DataArray.from_iris(output_item)
-                        if type(output_item) == iris.cube.Cube
-                        else output_item.to_xarray()
-                        if type(output_item) == pd.DataFrame
-                        else output_item
+                        (
+                            xarray.DataArray.from_iris(output_item)
+                            if type(output_item) == iris.cube.Cube
+                            else (
+                                output_item.to_xarray()
+                                if type(output_item) == pd.DataFrame
+                                else output_item
+                            )
+                        )
                         for output_item in output
                     ]
                 )
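All of the decorator hunks above are formatting-only: newer black releases wrap chained conditional expressions in parentheses, which changes nothing about how they evaluate. A small standalone sketch (an invented example, not code from this PR) showing the two spellings are equivalent:

```python
def convert_flat(arg):
    # pre-reformat style: chained conditional expression on continuation lines
    return "cube" if isinstance(arg, tuple) else "frame" if isinstance(arg, list) else arg


def convert_wrapped(arg):
    # post-reformat style: the same right-associative expression, parenthesized
    return (
        "cube"
        if isinstance(arg, tuple)
        else ("frame" if isinstance(arg, list) else arg)
    )


# both spellings give identical results for every input
assert all(convert_flat(x) == convert_wrapped(x) for x in [(1,), [1], 3.5])
```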
1 change: 1 addition & 0 deletions tobac/utils/periodic_boundaries.py
@@ -1,5 +1,6 @@
"""Utilities for handling indexing and distance calculation with periodic boundaries
"""

from __future__ import annotations
import functools

