-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
316 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
"""Table Format metric.""" | ||
from sdmetrics.goal import Goal | ||
from sdmetrics.single_table.base import SingleTableMetric | ||
|
||
|
||
class TableFormat(SingleTableMetric): | ||
"""TableFormat Single Table metric. | ||
This metric computes whether the names and data types of each column are | ||
the same in the real and synthetic data. | ||
Attributes: | ||
name (str): | ||
Name to use when reports about this metric are printed. | ||
goal (sdmetrics.goal.Goal): | ||
The goal of this metric. | ||
min_value (Union[float, tuple[float]]): | ||
Minimum value or values that this metric can take. | ||
max_value (Union[float, tuple[float]]): | ||
Maximum value or values that this metric can take. | ||
""" | ||
|
||
name = 'TableFormat' | ||
goal = Goal.MAXIMIZE | ||
min_value = 0 | ||
max_value = 1 | ||
|
||
@classmethod | ||
def compute_breakdown(cls, real_data, synthetic_data, ignore_dtype_columns=None): | ||
"""Compute the score breakdown of the table format metric. | ||
Args: | ||
real_data (pandas.DataFrame): | ||
The real data. | ||
synthetic_data (pandas.DataFrame): | ||
The synthetic data. | ||
ignore_dtype_columns (list[str]): | ||
List of column names to ignore when comparing data types. | ||
Defaults to ``None``. | ||
""" | ||
ignore_dtype_columns = ignore_dtype_columns or [] | ||
missing_columns_in_synthetic = set(real_data.columns) - set(synthetic_data.columns) | ||
invalid_names = [] | ||
invalid_sdtypes = [] | ||
for column in synthetic_data.columns: | ||
if column not in real_data.columns: | ||
invalid_names.append(column) | ||
continue | ||
|
||
if column in ignore_dtype_columns: | ||
continue | ||
|
||
if synthetic_data[column].dtype != real_data[column].dtype: | ||
invalid_sdtypes.append(column) | ||
|
||
proportion_correct_columns = 1 - len(missing_columns_in_synthetic) / len(real_data.columns) | ||
proportion_valid_names = 1 - len(invalid_names) / len(synthetic_data.columns) | ||
proportion_valid_sdtypes = 1 - len(invalid_sdtypes) / len(synthetic_data.columns) | ||
|
||
score = proportion_correct_columns * proportion_valid_names * proportion_valid_sdtypes | ||
return { | ||
key: value | ||
for key, value in { | ||
'score': score, | ||
'missing columns in synthetic data': list(missing_columns_in_synthetic), | ||
'invalid column names': invalid_names, | ||
'invalid column data types': invalid_sdtypes | ||
}.items() | ||
if value | ||
} | ||
|
||
@classmethod | ||
def compute(cls, real_data, synthetic_data, ignore_dtype_columns=None): | ||
"""Compute the table format metric score. | ||
Args: | ||
real_data (pandas.DataFrame): | ||
The real data. | ||
synthetic_data (pandas.DataFrame): | ||
The synthetic data. | ||
ignore_dtype_columns (list[str]): | ||
List of column names to ignore when comparing data types. | ||
Defaults to ``None``. | ||
Returns: | ||
float: | ||
The metric score. | ||
""" | ||
return cls.compute_breakdown(real_data, synthetic_data, ignore_dtype_columns)['score'] | ||
|
||
@classmethod | ||
def normalize(cls, raw_score): | ||
"""Return the `raw_score` as is, since it is already normalized. | ||
Args: | ||
raw_score (float): | ||
The value of the metric from `compute`. | ||
Returns: | ||
float: | ||
The normalized value of the metric | ||
""" | ||
return super().normalize(raw_score) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
from unittest.mock import patch | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
from sdmetrics.single_table import TableFormat | ||
|
||
|
||
@pytest.fixture() | ||
def real_data(): | ||
return pd.DataFrame({ | ||
'col_1': [1, 2, 3, 4, 5], | ||
'col_2': ['A', 'B', 'C', 'B', 'A'], | ||
'col_3': [True, False, True, False, True], | ||
'col_4': pd.to_datetime([ | ||
'2020-01-01', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05' | ||
]), | ||
'col_5': [1.0, 2.0, 3.0, 4.0, 5.0] | ||
}) | ||
|
||
|
||
class TestTableFormat: | ||
|
||
def test_compute_breakdown(self, real_data): | ||
"""Test the ``compute_breakdown`` method.""" | ||
# Setup | ||
synthetic_data = pd.DataFrame({ | ||
'col_1': [3, 2, 1, 4, 5], | ||
'col_2': ['A', 'B', 'C', 'D', 'E'], | ||
'col_3': [True, False, True, False, True], | ||
'col_4': pd.to_datetime([ | ||
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05' | ||
]), | ||
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0] | ||
}) | ||
|
||
metric = TableFormat() | ||
|
||
# Run | ||
result = metric.compute_breakdown(real_data, synthetic_data) | ||
|
||
# Assert | ||
expected_result = { | ||
'score': 1.0, | ||
} | ||
assert result == expected_result | ||
|
||
def test_compute_breakdown_with_missing_columns(self, real_data): | ||
"""Test the ``compute_breakdown`` method with missing columns.""" | ||
# Setup | ||
synthetic_data = pd.DataFrame({ | ||
'col_1': [3, 2, 1, 4, 5], | ||
'col_2': ['A', 'B', 'C', 'D', 'E'], | ||
'col_3': [True, False, True, False, True], | ||
'col_4': pd.to_datetime([ | ||
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05' | ||
]), | ||
}) | ||
|
||
metric = TableFormat() | ||
|
||
# Run | ||
result = metric.compute_breakdown(real_data, synthetic_data) | ||
|
||
# Assert | ||
expected_result = { | ||
'score': 0.8, | ||
'missing columns in synthetic data': ['col_5'] | ||
} | ||
assert result == expected_result | ||
|
||
def test_compute_breakdown_with_invalid_names(self, real_data): | ||
"""Test the ``compute_breakdown`` method with invalid names.""" | ||
# Setup | ||
synthetic_data = pd.DataFrame({ | ||
'col_1': [3, 2, 1, 4, 5], | ||
'col_2': ['A', 'B', 'C', 'D', 'E'], | ||
'col_3': [True, False, True, False, True], | ||
'col_4': pd.to_datetime([ | ||
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05' | ||
]), | ||
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0], | ||
'col_6': [4.0, 2.0, 3.0, 4.0, 5.0], | ||
}) | ||
|
||
metric = TableFormat() | ||
|
||
# Run | ||
result = metric.compute_breakdown(real_data, synthetic_data) | ||
|
||
# Assert | ||
expected_result = { | ||
'score': 0.8333333333333334, | ||
'invalid column names': ['col_6'] | ||
} | ||
assert result == expected_result | ||
|
||
def test_compute_breakdown_with_invalid_dtypes(self, real_data): | ||
"""Test the ``compute_breakdown`` method with invalid dtypes.""" | ||
# Setup | ||
synthetic_data = pd.DataFrame({ | ||
'col_1': [3.0, 2.0, 1.0, 4.0, 5.0], | ||
'col_2': ['A', 'B', 'C', 'D', 'E'], | ||
'col_3': [True, False, True, False, True], | ||
'col_4': [ | ||
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05' | ||
], | ||
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0], | ||
}) | ||
|
||
metric = TableFormat() | ||
|
||
# Run | ||
result = metric.compute_breakdown(real_data, synthetic_data) | ||
|
||
# Assert | ||
expected_result = { | ||
'score': 0.6, | ||
'invalid column data types': ['col_1', 'col_4'] | ||
} | ||
assert result == expected_result | ||
|
||
def test_compute_breakdown_ignore_dtype_columns(self, real_data): | ||
"""Test the ``compute_breakdown`` method when ignore_dtype_columns is set.""" | ||
# Setup | ||
synthetic_data = pd.DataFrame({ | ||
'col_1': [3.0, 2.0, 1.0, 4.0, 5.0], | ||
'col_2': ['A', 'B', 'C', 'D', 'E'], | ||
'col_3': [True, False, True, False, True], | ||
'col_4': [ | ||
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05' | ||
], | ||
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0], | ||
}) | ||
|
||
metric = TableFormat() | ||
|
||
# Run | ||
result = metric.compute_breakdown( | ||
real_data, synthetic_data, ignore_dtype_columns=['col_4'] | ||
) | ||
|
||
# Assert | ||
expected_result = { | ||
'score': 0.8, | ||
'invalid column data types': ['col_1'] | ||
} | ||
assert result == expected_result | ||
|
||
def test_compute_breakdown_multiple_error(self, real_data): | ||
"""Test the ``compute_breakdown`` method with the different failure modes.""" | ||
synthetic_data = pd.DataFrame({ | ||
'col_1': [1, 2, 1, 4, 5], | ||
'col_3': [True, False, True, False, True], | ||
'col_4': [ | ||
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05' | ||
], | ||
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0], | ||
'col_6': [4.0, 2.0, 3.0, 4.0, 5.0], | ||
}) | ||
|
||
metric = TableFormat() | ||
|
||
# Run | ||
result = metric.compute_breakdown(real_data, synthetic_data) | ||
|
||
# Assert | ||
expected_result = { | ||
'score': 0.5120000000000001, | ||
'missing columns in synthetic data': ['col_2'], | ||
'invalid column names': ['col_6'], | ||
'invalid column data types': ['col_4'] | ||
} | ||
assert result == expected_result | ||
|
||
@patch('sdmetrics.single_table.table_format.TableFormat.compute_breakdown') | ||
def test_compute(self, compute_breakdown_mock, real_data): | ||
"""Test the ``compute`` method.""" | ||
# Setup | ||
synthetic_data = pd.DataFrame({ | ||
'col_1': [3, 2, 1, 4, 5], | ||
'col_2': ['A', 'B', 'C', 'D', 'E'], | ||
'col_3': [True, False, True, False, True], | ||
'col_4': pd.to_datetime([ | ||
'2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05' | ||
]), | ||
'col_5': [4.0, 2.0, 3.0, 4.0, 5.0] | ||
}) | ||
metric = TableFormat() | ||
compute_breakdown_mock.return_value = {'score': 0.6} | ||
|
||
# Run | ||
result = metric.compute(real_data, synthetic_data) | ||
|
||
# Assert | ||
compute_breakdown_mock.assert_called_once_with(real_data, synthetic_data, None) | ||
assert result == 0.6 | ||
|
||
@patch('sdmetrics.single_table.table_format.SingleTableMetric.normalize') | ||
def test_normalize(self, normalize_mock): | ||
"""Test the ``normalize`` method.""" | ||
# Setup | ||
metric = TableFormat() | ||
raw_score = 0.9 | ||
|
||
# Run | ||
result = metric.normalize(raw_score) | ||
|
||
# Assert | ||
normalize_mock.assert_called_once_with(raw_score) | ||
assert result == normalize_mock.return_value |