From ef825d98c7b14401c449bba4a7b42b5ca5483dcb Mon Sep 17 00:00:00 2001
From: Felipe Alex Hofmann <fealho@gmail.com>
Date: Fri, 22 Nov 2024 09:59:38 -0800
Subject: [PATCH] Improve warning handling for non-positive values when
 `apply_log = True` for `InterRowMSAS` (#671)

---
 .../statistical/inter_row_msas.py             | 99 +++++++++++--------
 .../statistical/test_inter_row_msas.py        | 25 +++++
 2 files changed, 84 insertions(+), 40 deletions(-)

diff --git a/sdmetrics/column_pairs/statistical/inter_row_msas.py b/sdmetrics/column_pairs/statistical/inter_row_msas.py
index eea77f06..200b58fa 100644
--- a/sdmetrics/column_pairs/statistical/inter_row_msas.py
+++ b/sdmetrics/column_pairs/statistical/inter_row_msas.py
@@ -29,7 +29,61 @@ class InterRowMSAS:
     max_value = 1.0
 
     @staticmethod
-    def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False):
+    def _validate_inputs(real_data, synthetic_data, n_rows_diff, apply_log):
+        """Raise ``ValueError`` if any argument given to ``compute`` is malformed."""
+        for data in (real_data, synthetic_data):
+            is_pair = isinstance(data, tuple) and len(data) == 2
+            if not (is_pair and all(isinstance(column, pd.Series) for column in data)):
+                raise ValueError('The data must be a tuple of two pandas series.')
+
+        # ``bool`` subclasses ``int``, so ``True`` would otherwise be accepted as a row offset.
+        is_int = isinstance(n_rows_diff, int) and not isinstance(n_rows_diff, bool)
+        if not is_int or n_rows_diff < 1:
+            raise ValueError("'n_rows_diff' must be an integer greater than zero.")
+
+        if not isinstance(apply_log, bool):
+            raise ValueError("'apply_log' must be a boolean.")
+
+    @staticmethod
+    def _apply_log(real_values, synthetic_values, apply_log):
+        """Return both series log-transformed when ``apply_log`` is True, else unchanged."""
+        if not apply_log:
+            return real_values, synthetic_values
+        # Vectorized count of values <= 0 (NaN compares False, so it is never counted).
+        num_invalid = int((pd.concat((real_values, synthetic_values)) <= 0).sum())
+        if num_invalid:
+            warnings.warn(
+                f'There are {num_invalid} non-positive values in your data, which cannot be '
+                "used with log. Consider changing 'apply_log' to False for a better result."
+            )
+        with warnings.catch_warnings():
+            warnings.filterwarnings('ignore', message='.*encountered in log')
+            return np.log(real_values), np.log(synthetic_values)
+
+    @staticmethod
+    def _calculate_differences(keys, values, n_rows_diff, data_name):
+        """Return, per sequence key, the mean ``n_rows_diff``-lagged difference of ``values``."""
+        sequences = values.groupby(keys)
+        num_invalid_groups = int((sequences.size() <= n_rows_diff).sum())
+        if num_invalid_groups:
+            warnings.warn(
+                f"n_rows_diff '{n_rows_diff}' is greater than the "
+                f'size of {num_invalid_groups} sequence keys in {data_name}.'
+            )
+
+        def mean_lagged_diff(sequence):
+            # A sequence with no two rows ``n_rows_diff`` apart has nothing to average.
+            if len(sequence) <= n_rows_diff:
+                return np.nan
+            array = sequence.to_numpy()
+            return np.mean(array[n_rows_diff:] - array[:-n_rows_diff])
+
+        with warnings.catch_warnings():
+            warnings.filterwarnings('ignore', message='invalid value encountered in.*')
+            return sequences.apply(mean_lagged_diff)
+
+    @classmethod
+    def compute(cls, real_data, synthetic_data, n_rows_diff=1, apply_log=False):
         """Compute this metric.
 
         This metric compares the inter-row differences of sequences in the real data
@@ -58,48 +112,13 @@ def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False):
             float:
                 The similarity score between the real and synthetic data distributions.
         """
-        for data in [real_data, synthetic_data]:
-            if (
-                not isinstance(data, tuple)
-                or len(data) != 2
-                or (not (isinstance(data[0], pd.Series) and isinstance(data[1], pd.Series)))
-            ):
-                raise ValueError('The data must be a tuple of two pandas series.')
-
-        if not isinstance(n_rows_diff, int) or n_rows_diff < 1:
-            raise ValueError("'n_rows_diff' must be an integer greater than zero.")
-
-        if not isinstance(apply_log, bool):
-            raise ValueError("'apply_log' must be a boolean.")
-
+        cls._validate_inputs(real_data, synthetic_data, n_rows_diff, apply_log)
         real_keys, real_values = real_data
         synthetic_keys, synthetic_values = synthetic_data
+        real_values, synthetic_values = cls._apply_log(real_values, synthetic_values, apply_log)
 
-        if apply_log:
-            real_values = np.log(real_values)
-            synthetic_values = np.log(synthetic_values)
-
-        def calculate_differences(keys, values, n_rows_diff, data_name):
-            group_sizes = values.groupby(keys).size()
-            num_invalid_groups = group_sizes[group_sizes <= n_rows_diff].count()
-            if num_invalid_groups > 0:
-                warnings.warn(
-                    f"n_rows_diff '{n_rows_diff}' is greater than the "
-                    f'size of {num_invalid_groups} sequence keys in {data_name}.'
-                )
-
-            differences = values.groupby(keys).apply(
-                lambda group: np.mean(
-                    group.to_numpy()[n_rows_diff:] - group.to_numpy()[:-n_rows_diff]
-                )
-                if len(group) > n_rows_diff
-                else np.nan
-            )
-
-            return pd.Series(differences)
-
-        real_diff = calculate_differences(real_keys, real_values, n_rows_diff, 'real_data')
-        synthetic_diff = calculate_differences(
+        real_diff = cls._calculate_differences(real_keys, real_values, n_rows_diff, 'real_data')
+        synthetic_diff = cls._calculate_differences(
             synthetic_keys, synthetic_values, n_rows_diff, 'synthetic_data'
         )
 
diff --git a/tests/unit/column_pairs/statistical/test_inter_row_msas.py b/tests/unit/column_pairs/statistical/test_inter_row_msas.py
index 9a3552db..a88e375f 100644
--- a/tests/unit/column_pairs/statistical/test_inter_row_msas.py
+++ b/tests/unit/column_pairs/statistical/test_inter_row_msas.py
@@ -71,6 +71,31 @@ def test_compute_with_log(self):
         # Assert
         assert score == 1
 
+    def test_compute_with_log_warning(self):
+        """Test that ``apply_log=True`` warns exactly once about non-positive values."""
+        # Setup
+        keys = ['id1', 'id1', 'id1', 'id2', 'id2', 'id2']
+        # Three of the values below (-1, -10 and -4) are non-positive.
+        real_data = (pd.Series(keys), pd.Series([1, 1.4, 4, -1, 16, -10]))
+        synthetic_data = (pd.Series(keys), pd.Series([1, 2, -4, 8, 16, 30]))
+
+        expected_message = (
+            'There are 3 non-positive values in your data, which cannot be used with log. '
+            "Consider changing 'apply_log' to False for a better result."
+        )
+
+        # Run
+        with pytest.warns(UserWarning) as warning_info:
+            score = InterRowMSAS.compute(
+                real_data=real_data, synthetic_data=synthetic_data, apply_log=True
+            )
+
+        # Assert
+        # The metric's own warning must be the only one recorded.
+        assert len(warning_info) == 1
+        assert str(warning_info[0].message) == expected_message
+        assert score == 0
+
     def test_compute_different_n_rows_diff(self):
         """Test it with different n_rows_diff."""
         # Setup