Skip to content

Commit

Permalink
Filter out keys that cannot be statistically modeled (#525)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 authored Nov 22, 2023
1 parent 87790f2 commit fa50cec
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 8 deletions.
37 changes: 29 additions & 8 deletions sdmetrics/single_table/detection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sdmetrics.errors import IncomputableMetricError
from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric
from sdmetrics.utils import HyperTransformer
from sdmetrics.utils import HyperTransformer, get_alternate_keys

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -43,6 +43,32 @@ def _fit_predict(X_train, y_train, X_test):
"""Fit a classifier and then use it to predict."""
raise NotImplementedError()

@staticmethod
def _drop_non_compute_columns(real_data, synthetic_data, metadata):
"""Drop all columns that cannot be statistically modeled."""
transformed_real_data = real_data
transformed_synthetic_data = synthetic_data

if metadata is not None:
drop_columns = []
drop_columns.extend(get_alternate_keys(metadata))
for column in metadata.get('columns', []):
if ('primary_key' in metadata and
(column == metadata['primary_key'] or
column in metadata['primary_key'])):
drop_columns.append(column)

column_info = metadata['columns'].get(column, {})
sdtype = column_info.get('sdtype')
pii = column_info.get('pii')
if sdtype not in ['numerical', 'datetime', 'categorical'] or pii:
drop_columns.append(column)

if drop_columns:
transformed_real_data = real_data.drop(drop_columns, axis=1)
transformed_synthetic_data = synthetic_data.drop(drop_columns, axis=1)
return transformed_real_data, transformed_synthetic_data

@classmethod
def compute(cls, real_data, synthetic_data, metadata=None):
"""Compute this metric.
Expand All @@ -68,13 +94,8 @@ def compute(cls, real_data, synthetic_data, metadata=None):
real_data, synthetic_data, metadata = cls._validate_inputs(
real_data, synthetic_data, metadata)

if metadata is not None and 'primary_key' in metadata:
transformed_real_data = real_data.drop(metadata['primary_key'], axis=1)
transformed_synthetic_data = synthetic_data.drop(metadata['primary_key'], axis=1)

else:
transformed_real_data = real_data
transformed_synthetic_data = synthetic_data
transformed_real_data, transformed_synthetic_data = cls._drop_non_compute_columns(
real_data, synthetic_data, metadata)

ht = HyperTransformer()
transformed_real_data = ht.fit_transform(transformed_real_data).to_numpy()
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/single_table/detection/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,80 @@ def test_primary_key_detection_metrics(self, fit_transform_mock, transform_mock)
transform_mock.assert_called_with(expected_return_synthetic)
assert expected_return_real == call_1
assert expected_return_synthetic == call_2

@patch('sdmetrics.utils.HyperTransformer.transform')
@patch('sdmetrics.utils.HyperTransformer.fit_transform')
def test_ignore_keys_detection_metrics(self, fit_transform_mock, transform_mock):
"""This test checks that ``primary_key`` columns of dataset are ignored.
Ensure that ``primary_keys`` are ignored for Detection metrics expect that the match
is made correctly.
"""

# Setup
real_data = pd.DataFrame({
'ID_1': [1, 2, 1, 3, 4],
'col1': [43.0, 47.5, 34.2, 30.3, 39.1],
'col2': [1.0, 2.0, 3.0, 4.0, 5.0],
'ID_2': ['aa', 'bb', 'cc', 'dd', 'bb'],
'col3': [5, 6, 7, 8, 9],
'ID_3': ['a', 'b', 'c', 'd', 'e'],
'blob': ['Hello world!', 'Hello world!', 'This is SDV', 'This is SDV', 'Hello world!'],
'col4': [1, 3, 9, 2, 1],
'col5': [10, 20, 30, 40, 50]
})
synthetic_data = pd.DataFrame({
'ID_1': [1, 3, 4, 2, 2],
'col1': [23.0, 47.1, 44.9, 31.3, 9.7],
'col2': [11.0, 22.0, 33.0, 44.0, 55.0],
'ID_2': ['aa', 'bb', 'cc', 'dd', 'ee'],
'col3': [55, 66, 77, 88, 99],
'ID_3': ['a', 'b', 'e', 'd', 'c'],
'blob': ['Hello world!', 'Hello world!', 'This is SDV', 'This is SDV', 'Hello world!'],
'col4': [4, 1, 3, 1, 9],
'col5': [10, 20, 30, 40, 50]
})
metadata = {
'columns': {
'ID_1': {'sdtype': 'numerical'},
'col1': {'sdtype': 'numerical', 'pii': True},
'col2': {'sdtype': 'numerical'},
'ID_2': {'sdtype': 'categorical'},
'col3': {'sdtype': 'numerical'},
'ID_3': {'sdtype': 'id'},
'blob': {'sdtype': 'text'},
'col4': {'sdtype': 'numerical', 'pii': False},
'col5': {'sdtype': 'numerical'}
},
'primary_key': {'ID_1', 'ID_2'},
'alternate_keys': ['col5']
}

expected_real_dataframe = pd.DataFrame({
'col2': [1.0, 2.0, 3.0, 4.0, 5.0],
'col3': [5, 6, 7, 8, 9],
'col4': [1, 3, 9, 2, 1]
})
expected_synthetic_dataframe = pd.DataFrame({
'col2': [11.0, 22.0, 33.0, 44.0, 55.0],
'col3': [55, 66, 77, 88, 99],
'col4': [4, 1, 3, 1, 9]
})

expected_return_real = DataFrameMatcher(expected_real_dataframe)
expected_return_synthetic = DataFrameMatcher(expected_synthetic_dataframe)
fit_transform_mock.return_value = expected_real_dataframe
transform_mock.return_value = expected_synthetic_dataframe

# Run
LogisticDetection().compute(real_data, synthetic_data, metadata)

# Assert

# check that ``fit_transform`` and ``transform`` received the good argument.
call_1 = pd.DataFrame(fit_transform_mock.call_args_list[0][0][0])
call_2 = pd.DataFrame(transform_mock.call_args_list[0][0][0])

transform_mock.assert_called_with(expected_return_synthetic)
assert expected_return_real == call_1
assert expected_return_synthetic == call_2

0 comments on commit fa50cec

Please sign in to comment.