From 8b450dfcc87715cc028f76990c9cfc41fbfc2a8d Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Mon, 25 Nov 2024 10:04:13 -0800 Subject: [PATCH] Improve error handling for datetime values when `apply_log = True` for `InterRowMSAS` (#673) --- .../statistical/inter_row_msas.py | 8 +++++++ .../statistical/test_inter_row_msas.py | 22 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/sdmetrics/column_pairs/statistical/inter_row_msas.py b/sdmetrics/column_pairs/statistical/inter_row_msas.py index 200b58fa..4755621d 100644 --- a/sdmetrics/column_pairs/statistical/inter_row_msas.py +++ b/sdmetrics/column_pairs/statistical/inter_row_msas.py @@ -47,6 +47,14 @@ def _validate_inputs(real_data, synthetic_data, n_rows_diff, apply_log): @staticmethod def _apply_log(real_values, synthetic_values, apply_log): if apply_log: + if pd.api.types.is_datetime64_any_dtype( + real_values + ) or pd.api.types.is_datetime64_any_dtype(synthetic_values): + raise TypeError( + 'Cannot compute log for datetime columns. ' + "Please set 'apply_log' to False to use this metric." + ) + num_invalid = sum(x <= 0 for x in pd.concat((real_values, synthetic_values))) if num_invalid: warnings.warn( diff --git a/tests/unit/column_pairs/statistical/test_inter_row_msas.py b/tests/unit/column_pairs/statistical/test_inter_row_msas.py index a88e375f..647b6569 100644 --- a/tests/unit/column_pairs/statistical/test_inter_row_msas.py +++ b/tests/unit/column_pairs/statistical/test_inter_row_msas.py @@ -1,3 +1,5 @@ +from datetime import datetime + import pandas as pd import pytest @@ -96,6 +98,26 @@ def test_compute_with_log_warning(self): assert str(warning_info[0].message) == expected_message assert score == 0 + def test_compute_with_log_datetime(self): + """Test it crashes for logs of datetime values.""" + # Setup + real_keys = pd.Series(['id1', 'id1']) + real_values = pd.Series([datetime(2020, 10, 1), datetime(2020, 10, 1)]) + synthetic_keys = pd.Series(['id2', 'id2']) + synthetic_values = pd.Series([datetime(2020, 10, 1), datetime(2020, 10, 1)]) + + # Run and Assert + err_msg = ( + 'Cannot compute log for datetime columns. ' + "Please set 'apply_log' to False to use this metric." + ) + with pytest.raises(TypeError, match=err_msg): + InterRowMSAS.compute( + real_data=(real_keys, real_values), + synthetic_data=(synthetic_keys, synthetic_values), + apply_log=True, + ) + def test_compute_different_n_rows_diff(self): """Test it with different n_rows_diff.""" # Setup