The ContingencySimilarity metric should be able to discretize continuous columns #702

Merged · 3 commits · Jan 9, 2025
43 changes: 40 additions & 3 deletions sdmetrics/column_pairs/statistical/contingency_similarity.py
@@ -1,7 +1,10 @@
"""Contingency Similarity Metric."""

import pandas as pd

from sdmetrics.column_pairs.base import ColumnPairsMetric
from sdmetrics.goal import Goal
from sdmetrics.utils import discretize_column


class ContingencySimilarity(ColumnPairsMetric):
@@ -23,23 +26,57 @@ class ContingencySimilarity(ColumnPairsMetric):
min_value = 0.0
max_value = 1.0

@staticmethod
def _validate_inputs(real_data, synthetic_data, continuous_column_names, num_discrete_bins):
for data in [real_data, synthetic_data]:
if not isinstance(data, pd.DataFrame) or len(data.columns) != 2:
raise ValueError('The data must be a pandas DataFrame with two columns.')

if set(real_data.columns) != set(synthetic_data.columns):
raise ValueError('The columns in the real and synthetic data must match.')

if continuous_column_names is not None:
bad_continuous_columns = "' ,'".join([
column for column in continuous_column_names if column not in real_data.columns
])
if bad_continuous_columns:
raise ValueError(
f"Continuous column(s) '{bad_continuous_columns}' not found in the data."
)

if not isinstance(num_discrete_bins, int) or num_discrete_bins <= 0:
raise ValueError('`num_discrete_bins` must be an integer greater than zero.')

@classmethod
def compute(cls, real_data, synthetic_data):
def compute(cls, real_data, synthetic_data, continuous_column_names=None, num_discrete_bins=10):
"""Compare the contingency similarity of two discrete columns.

Args:
real_data (Union[numpy.ndarray, pandas.Series]):
real_data (pd.DataFrame):
The values from the real dataset.
synthetic_data (Union[numpy.ndarray, pandas.Series]):
synthetic_data (pd.DataFrame):
The values from the synthetic dataset.
continuous_column_names (list[str], optional):
The list of columns to discretize before running the metric. The column names in
this list should match the column names in the real and synthetic data. Defaults
to ``None``.
num_discrete_bins (int, optional):
The number of bins to create for the continuous columns. Defaults to 10.

Returns:
float:
The contingency similarity of the two columns.
"""
cls._validate_inputs(real_data, synthetic_data, continuous_column_names, num_discrete_bins)
columns = real_data.columns[:2]
real = real_data[columns]
synthetic = synthetic_data[columns]
if continuous_column_names is not None:
for column in continuous_column_names:
real[column], synthetic[column] = discretize_column(
real[column], synthetic[column], num_discrete_bins=num_discrete_bins
)

contingency_real = real.groupby(list(columns), dropna=False).size() / len(real)
contingency_synthetic = synthetic.groupby(list(columns), dropna=False).size() / len(
synthetic
)
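
For context, here is a minimal usage sketch of the new keyword arguments. The import path and ``compute`` signature come from the diff above; the DataFrames and column names are made up for illustration.

```python
import pandas as pd

from sdmetrics.column_pairs.statistical.contingency_similarity import ContingencySimilarity

# Hypothetical two-column tables: one discrete column, one continuous column.
real = pd.DataFrame({
    'category': ['a', 'b', 'a', 'c'],
    'amount': [1.0, 2.4, 2.6, 0.8],
})
synthetic = pd.DataFrame({
    'category': ['a', 'b', 'c', 'a'],
    'amount': [1.1, 2.0, 2.5, 0.9],
})

# 'amount' is binned into 4 discrete buckets before the contingency tables are compared.
score = ContingencySimilarity.compute(
    real,
    synthetic,
    continuous_column_names=['amount'],
    num_discrete_bins=4,
)
print(score)  # a float between min_value=0.0 and max_value=1.0
```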
5 changes: 2 additions & 3 deletions sdmetrics/reports/utils.py
@@ -9,6 +9,7 @@
from pandas.core.tools.datetimes import _guess_datetime_format_for_array

from sdmetrics.utils import (
discretize_column,
get_alternate_keys,
get_columns_from_metadata,
get_type_from_column_meta,
@@ -116,9 +117,7 @@ def discretize_table_data(real_data, synthetic_data, metadata):
real_col = pd.to_numeric(real_col)
synthetic_col = pd.to_numeric(synthetic_col)

bin_edges = np.histogram_bin_edges(real_col.dropna())
binned_real_col = np.digitize(real_col, bins=bin_edges)
binned_synthetic_col = np.digitize(synthetic_col, bins=bin_edges)
binned_real_col, binned_synthetic_col = discretize_column(real_col, synthetic_col)

binned_real[column_name] = binned_real_col
binned_synthetic[column_name] = binned_synthetic_col
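
The switch to ``discretize_column`` is not purely cosmetic: the new helper (added to ``sdmetrics/utils.py`` below) widens the outermost bin edges to ±infinity, so values outside the real column's range are clamped into the first or last bin instead of spilling into extra labels 0 and 11. A small NumPy sketch of the difference, with illustrative values:

```python
import numpy as np

real_col = np.arange(10)               # values 0..9 define the bin edges
synthetic_col = np.array([-5, 3, 20])  # contains out-of-range values

# Previous behaviour: plain histogram edges, so out-of-range values land in bins 0 and 11.
edges = np.histogram_bin_edges(real_col)        # 11 edges -> 10 bins
print(np.digitize(synthetic_col, bins=edges))   # [ 0  4 11]

# New behaviour (discretize_column): outermost edges widened to +/- infinity,
# so every value is clamped into bins 1..10.
edges[0], edges[-1] = -np.inf, np.inf
print(np.digitize(synthetic_col, bins=edges))   # [ 1  4 10]
```

This is why the expected bin labels in ``tests/unit/reports/test_utils.py`` further down change from 11 to 10 and from 0 to 1.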
22 changes: 22 additions & 0 deletions sdmetrics/utils.py
@@ -123,6 +123,28 @@ def is_datetime(data):
)


def discretize_column(real_column, synthetic_column, num_discrete_bins=10):
"""Discretize a real and synthetic column.

Args:
real_column (pd.Series):
The real column.
synthetic_column (pd.Series):
The synthetic column.
num_discrete_bins (int, optional):
The number of bins to create. Defaults to 10.

Returns:
tuple(pd.Series, pd.Series):
The discretized real and synthetic columns.
"""
bin_edges = np.histogram_bin_edges(real_column.dropna(), bins=num_discrete_bins)
bin_edges[0], bin_edges[-1] = -np.inf, np.inf
binned_real_column = np.digitize(real_column, bins=bin_edges)
binned_synthetic_column = np.digitize(synthetic_column, bins=bin_edges)
return binned_real_column, binned_synthetic_column


class HyperTransformer:
"""HyperTransformer class.

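
A quick sketch of the new helper on its own; the Series values are illustrative.

```python
import pandas as pd

from sdmetrics.utils import discretize_column

real = pd.Series([0.1, 0.5, 1.2, 3.4, 9.9])
synthetic = pd.Series([-2.0, 0.4, 1.1, 5.0, 42.0])  # includes values outside the real range

binned_real, binned_synthetic = discretize_column(real, synthetic, num_discrete_bins=3)

# Bin edges are computed from the real column only, then the outermost edges are
# replaced with +/- infinity, so both outputs contain integer labels 1..3 and
# out-of-range synthetic values fall into the first or last bin.
print(binned_real)
print(binned_synthetic)
```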
@@ -1,3 +1,4 @@
import re
from unittest.mock import patch

import pandas as pd
@@ -7,6 +8,59 @@


class TestContingencySimilarity:
def test__validate_inputs(self):
"""Test the ``_validate_inputs`` method."""
# Setup
bad_data = pd.Series(range(5))
real_data = pd.DataFrame({'col1': range(10), 'col2': range(10, 20)})
bad_synthetic_data = pd.DataFrame({'bad_column': range(10), 'col2': range(10)})
synthetic_data = pd.DataFrame({'col1': range(5), 'col2': range(5)})
bad_continous_columns = ['col1', 'missing_col']
bad_num_discrete_bins = -1

# Run and Assert
expected_bad_data = re.escape('The data must be a pandas DataFrame with two columns.')
with pytest.raises(ValueError, match=expected_bad_data):
ContingencySimilarity._validate_inputs(
real_data=bad_data,
synthetic_data=bad_data,
continuous_column_names=None,
num_discrete_bins=10,
)

expected_mismatch_columns_error = re.escape(
'The columns in the real and synthetic data must match.'
)
with pytest.raises(ValueError, match=expected_mismatch_columns_error):
ContingencySimilarity._validate_inputs(
real_data=real_data,
synthetic_data=bad_synthetic_data,
continuous_column_names=None,
num_discrete_bins=10,
)

expected_bad_continous_column_error = re.escape(
"Continuous column(s) 'missing_col' not found in the data."
)
with pytest.raises(ValueError, match=expected_bad_continous_column_error):
ContingencySimilarity._validate_inputs(
real_data=real_data,
synthetic_data=synthetic_data,
continuous_column_names=bad_continous_columns,
num_discrete_bins=10,
)

expected_bad_num_discrete_bins_error = re.escape(
'`num_discrete_bins` must be an integer greater than zero.'
)
with pytest.raises(ValueError, match=expected_bad_num_discrete_bins_error):
ContingencySimilarity._validate_inputs(
real_data=real_data,
synthetic_data=synthetic_data,
continuous_column_names=['col1'],
num_discrete_bins=bad_num_discrete_bins,
)

def test_compute(self):
"""Test the ``compute`` method.

@@ -32,6 +86,22 @@ def test_compute(self):
# Assert
assert result == expected_score

def test_compute_with_discretization(self):
"""Test the ``compute`` method with continuous columns."""
# Setup
real_data = pd.DataFrame({'col1': [1.0, 2.4, 2.6, 0.8], 'col2': [1, 2, 3, 4]})
synthetic_data = pd.DataFrame({'col1': [1.0, 1.8, 2.6, 1.0], 'col2': [2, 3, 7, -10]})
expected_score = 0.25

# Run
metric = ContingencySimilarity()
result = metric.compute(
real_data, synthetic_data, continuous_column_names=['col2'], num_discrete_bins=4
)

# Assert
assert result == expected_score

@patch('sdmetrics.column_pairs.statistical.contingency_similarity.ColumnPairsMetric.normalize')
def test_normalize(self, normalize_mock):
"""Test the ``normalize`` method.
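
As a sanity check on ``expected_score = 0.25`` in ``test_compute_with_discretization``, the snippet below reproduces the arithmetic by hand. It assumes the metric finishes by taking one minus the total variation distance between the two contingency tables (that part of ``compute`` is collapsed in the diff above); the ``col2_bin`` column name is made up for illustration.

```python
import pandas as pd

# With 4 bins built from the real col2 values [1, 2, 3, 4], the synthetic
# col2 values [2, 3, 7, -10] discretize to [2, 3, 4, 1].
real = pd.DataFrame({'col1': [1.0, 2.4, 2.6, 0.8], 'col2_bin': [1, 2, 3, 4]})
synthetic = pd.DataFrame({'col1': [1.0, 1.8, 2.6, 1.0], 'col2_bin': [2, 3, 4, 1]})

cols = ['col1', 'col2_bin']
freq_real = real.groupby(cols).size() / len(real)
freq_synthetic = synthetic.groupby(cols).size() / len(synthetic)
diff = freq_real.subtract(freq_synthetic, fill_value=0).abs()

# Only the (1.0, 1) pair appears in both tables, so the total variation
# distance is 0.75 and the score is 1 - 0.75 = 0.25.
print(1 - 0.5 * diff.sum())  # 0.25
```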
24 changes: 12 additions & 12 deletions tests/unit/reports/test_utils.py
@@ -119,18 +119,18 @@ def test_discretize_table_data():

# Assert
expected_real = pd.DataFrame({
'col1': [1, 6, 11],
'col1': [1, 6, 10],
'col2': ['a', 'b', 'c'],
'col3': [2, 1, 11],
'col3': [2, 1, 10],
'col4': [True, False, True],
'col5': [10, 1, 11],
'col5': [10, 1, 10],
})
expected_synth = pd.DataFrame({
'col1': [11, 1, 11],
'col1': [10, 1, 10],
'col2': ['c', 'a', 'c'],
'col3': [11, 0, 5],
'col3': [10, 1, 5],
'col4': [False, False, True],
'col5': [10, 5, 11],
'col5': [10, 5, 10],
})

pd.testing.assert_frame_equal(discretized_real, expected_real)
@@ -193,18 +193,18 @@ def test_discretize_table_data_new_metadata():

# Assert
expected_real = pd.DataFrame({
'col1': [1, 6, 11],
'col1': [1, 6, 10],
'col2': ['a', 'b', 'c'],
'col3': [2, 1, 11],
'col3': [2, 1, 10],
'col4': [True, False, True],
'col5': [10, 1, 11],
'col5': [10, 1, 10],
})
expected_synth = pd.DataFrame({
'col1': [11, 1, 11],
'col1': [10, 1, 10],
'col2': ['c', 'a', 'c'],
'col3': [11, 0, 5],
'col3': [10, 1, 5],
'col4': [False, False, True],
'col5': [10, 5, 11],
'col5': [10, 5, 10],
})

pd.testing.assert_frame_equal(discretized_real, expected_real)
16 changes: 16 additions & 0 deletions tests/unit/test_utils.py
@@ -6,6 +6,7 @@

from sdmetrics.utils import (
HyperTransformer,
discretize_column,
get_alternate_keys,
get_cardinality_distribution,
get_columns_from_metadata,
@@ -54,6 +55,21 @@ def test_get_missing_percentage():
assert percentage_nan == 28.57


def test_discretize_column():
"""Test the ``discretize_column`` method."""
# Setup
real = pd.Series(range(10))
synthetic = pd.Series([-10] + list(range(1, 9)) + [20])
num_bins = 5

# Run
binned_real, binned_synthetic = discretize_column(real, synthetic, num_discrete_bins=num_bins)

# Assert
np.testing.assert_array_equal([1, 1, 2, 2, 3, 3, 4, 4, 5, 5], binned_real)
np.testing.assert_array_equal([1, 1, 2, 2, 3, 3, 4, 4, 5, 5], binned_synthetic)


def test_get_columns_from_metadata():
"""Test the ``get_columns_from_metadata`` method with current metadata format.
