Commit 32f9cca: def + test

R-Palazzo committed Oct 23, 2023
1 parent 99cb1e4
Showing 3 changed files with 316 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sdmetrics/single_table/__init__.py
@@ -32,6 +32,7 @@
from sdmetrics.single_table.privacy.numerical_sklearn import (
    NumericalLR, NumericalMLP, NumericalSVR)
from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
from sdmetrics.single_table.table_format import TableFormat

__all__ = [
    'bayesian_network',
@@ -90,4 +91,5 @@
    'TVComplement',
    'RangeCoverage',
    'NewRowSynthesis',
    'TableFormat',
]
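
With the new export in place, the metric is importable straight from the single-table namespace, which is exactly how the unit tests below pull it in:

from sdmetrics.single_table import TableFormat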
103 changes: 103 additions & 0 deletions sdmetrics/single_table/table_format.py
@@ -0,0 +1,103 @@
"""Table Format metric."""
from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric


class TableFormat(SingleTableMetric):
    """TableFormat Single Table metric.

    This metric computes whether the names and data types of each column are
    the same in the real and synthetic data.

    Attributes:
        name (str):
            Name to use when reports about this metric are printed.
        goal (sdmetrics.goal.Goal):
            The goal of this metric.
        min_value (Union[float, tuple[float]]):
            Minimum value or values that this metric can take.
        max_value (Union[float, tuple[float]]):
            Maximum value or values that this metric can take.
    """

    name = 'TableFormat'
    goal = Goal.MAXIMIZE
    min_value = 0
    max_value = 1

    @classmethod
    def compute_breakdown(cls, real_data, synthetic_data, ignore_dtype_columns=None):
        """Compute the score breakdown of the table format metric.

        Args:
            real_data (pandas.DataFrame):
                The real data.
            synthetic_data (pandas.DataFrame):
                The synthetic data.
            ignore_dtype_columns (list[str]):
                List of column names to ignore when comparing data types.
                Defaults to ``None``.

        Returns:
            dict:
                The score and, when applicable, the lists of missing columns,
                invalid column names and invalid column data types.
        """
        ignore_dtype_columns = ignore_dtype_columns or []
        missing_columns_in_synthetic = set(real_data.columns) - set(synthetic_data.columns)
        invalid_names = []
        invalid_sdtypes = []
        for column in synthetic_data.columns:
            if column not in real_data.columns:
                invalid_names.append(column)
                continue

            if column in ignore_dtype_columns:
                continue

            if synthetic_data[column].dtype != real_data[column].dtype:
                invalid_sdtypes.append(column)

        proportion_correct_columns = 1 - len(missing_columns_in_synthetic) / len(real_data.columns)
        proportion_valid_names = 1 - len(invalid_names) / len(synthetic_data.columns)
        proportion_valid_sdtypes = 1 - len(invalid_sdtypes) / len(synthetic_data.columns)

        score = proportion_correct_columns * proportion_valid_names * proportion_valid_sdtypes
        # Drop falsy entries so the breakdown only reports problems that occurred.
        return {
            key: value
            for key, value in {
                'score': score,
                'missing columns in synthetic data': list(missing_columns_in_synthetic),
                'invalid column names': invalid_names,
                'invalid column data types': invalid_sdtypes
            }.items()
            if value
        }

    @classmethod
    def compute(cls, real_data, synthetic_data, ignore_dtype_columns=None):
        """Compute the table format metric score.

        Args:
            real_data (pandas.DataFrame):
                The real data.
            synthetic_data (pandas.DataFrame):
                The synthetic data.
            ignore_dtype_columns (list[str]):
                List of column names to ignore when comparing data types.
                Defaults to ``None``.

        Returns:
            float:
                The metric score.
        """
        return cls.compute_breakdown(real_data, synthetic_data, ignore_dtype_columns)['score']

    @classmethod
    def normalize(cls, raw_score):
        """Return the ``raw_score`` as is, since it is already normalized.

        Args:
            raw_score (float):
                The value of the metric from ``compute``.

        Returns:
            float:
                The normalized value of the metric.
        """
        return super().normalize(raw_score)
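
To make the scoring concrete, here is a minimal usage sketch. It is illustrative only; the two tiny DataFrames are invented for this example and are not part of the commit:

import pandas as pd

from sdmetrics.single_table import TableFormat

real = pd.DataFrame({'a': [1, 2], 'b': ['x', 'y']})
synthetic = pd.DataFrame({'a': [1.0, 2.0], 'b': ['x', 'y']})  # 'a' drifted to float

# All columns are present and named correctly (those proportions are 1.0),
# but one of the two column dtypes is wrong: score = 1.0 * 1.0 * (1 - 1/2)
print(TableFormat.compute_breakdown(real, synthetic))
# {'score': 0.5, 'invalid column data types': ['a']}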
211 changes: 211 additions & 0 deletions tests/unit/single_table/test_table_format.py
@@ -0,0 +1,211 @@
from unittest.mock import patch

import pandas as pd
import pytest

from sdmetrics.single_table import TableFormat


@pytest.fixture()
def real_data():
    return pd.DataFrame({
        'col_1': [1, 2, 3, 4, 5],
        'col_2': ['A', 'B', 'C', 'B', 'A'],
        'col_3': [True, False, True, False, True],
        'col_4': pd.to_datetime([
            '2020-01-01', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
        ]),
        'col_5': [1.0, 2.0, 3.0, 4.0, 5.0]
    })


class TestTableFormat:

    def test_compute_breakdown(self, real_data):
        """Test the ``compute_breakdown`` method."""
        # Setup
        synthetic_data = pd.DataFrame({
            'col_1': [3, 2, 1, 4, 5],
            'col_2': ['A', 'B', 'C', 'D', 'E'],
            'col_3': [True, False, True, False, True],
            'col_4': pd.to_datetime([
                '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
            ]),
            'col_5': [4.0, 2.0, 3.0, 4.0, 5.0]
        })

        metric = TableFormat()

        # Run
        result = metric.compute_breakdown(real_data, synthetic_data)

        # Assert
        expected_result = {
            'score': 1.0,
        }
        assert result == expected_result

    def test_compute_breakdown_with_missing_columns(self, real_data):
        """Test the ``compute_breakdown`` method with missing columns."""
        # Setup
        synthetic_data = pd.DataFrame({
            'col_1': [3, 2, 1, 4, 5],
            'col_2': ['A', 'B', 'C', 'D', 'E'],
            'col_3': [True, False, True, False, True],
            'col_4': pd.to_datetime([
                '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
            ]),
        })

        metric = TableFormat()

        # Run
        result = metric.compute_breakdown(real_data, synthetic_data)

        # Assert
        expected_result = {
            'score': 0.8,
            'missing columns in synthetic data': ['col_5']
        }
        assert result == expected_result

    def test_compute_breakdown_with_invalid_names(self, real_data):
        """Test the ``compute_breakdown`` method with invalid names."""
        # Setup
        synthetic_data = pd.DataFrame({
            'col_1': [3, 2, 1, 4, 5],
            'col_2': ['A', 'B', 'C', 'D', 'E'],
            'col_3': [True, False, True, False, True],
            'col_4': pd.to_datetime([
                '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
            ]),
            'col_5': [4.0, 2.0, 3.0, 4.0, 5.0],
            'col_6': [4.0, 2.0, 3.0, 4.0, 5.0],
        })

        metric = TableFormat()

        # Run
        result = metric.compute_breakdown(real_data, synthetic_data)

        # Assert
        expected_result = {
            'score': 0.8333333333333334,
            'invalid column names': ['col_6']
        }
        assert result == expected_result

    def test_compute_breakdown_with_invalid_dtypes(self, real_data):
        """Test the ``compute_breakdown`` method with invalid dtypes."""
        # Setup
        synthetic_data = pd.DataFrame({
            'col_1': [3.0, 2.0, 1.0, 4.0, 5.0],
            'col_2': ['A', 'B', 'C', 'D', 'E'],
            'col_3': [True, False, True, False, True],
            'col_4': [
                '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
            ],
            'col_5': [4.0, 2.0, 3.0, 4.0, 5.0],
        })

        metric = TableFormat()

        # Run
        result = metric.compute_breakdown(real_data, synthetic_data)

        # Assert
        expected_result = {
            'score': 0.6,
            'invalid column data types': ['col_1', 'col_4']
        }
        assert result == expected_result

    def test_compute_breakdown_ignore_dtype_columns(self, real_data):
        """Test the ``compute_breakdown`` method when ``ignore_dtype_columns`` is set."""
        # Setup
        synthetic_data = pd.DataFrame({
            'col_1': [3.0, 2.0, 1.0, 4.0, 5.0],
            'col_2': ['A', 'B', 'C', 'D', 'E'],
            'col_3': [True, False, True, False, True],
            'col_4': [
                '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
            ],
            'col_5': [4.0, 2.0, 3.0, 4.0, 5.0],
        })

        metric = TableFormat()

        # Run
        result = metric.compute_breakdown(
            real_data, synthetic_data, ignore_dtype_columns=['col_4']
        )

        # Assert
        expected_result = {
            'score': 0.8,
            'invalid column data types': ['col_1']
        }
        assert result == expected_result

    def test_compute_breakdown_multiple_error(self, real_data):
        """Test the ``compute_breakdown`` method with the different failure modes."""
        # Setup
        synthetic_data = pd.DataFrame({
            'col_1': [1, 2, 1, 4, 5],
            'col_3': [True, False, True, False, True],
            'col_4': [
                '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
            ],
            'col_5': [4.0, 2.0, 3.0, 4.0, 5.0],
            'col_6': [4.0, 2.0, 3.0, 4.0, 5.0],
        })

        metric = TableFormat()

        # Run
        result = metric.compute_breakdown(real_data, synthetic_data)

        # Assert
        expected_result = {
            'score': 0.5120000000000001,
            'missing columns in synthetic data': ['col_2'],
            'invalid column names': ['col_6'],
            'invalid column data types': ['col_4']
        }
        assert result == expected_result

    @patch('sdmetrics.single_table.table_format.TableFormat.compute_breakdown')
    def test_compute(self, compute_breakdown_mock, real_data):
        """Test the ``compute`` method."""
        # Setup
        synthetic_data = pd.DataFrame({
            'col_1': [3, 2, 1, 4, 5],
            'col_2': ['A', 'B', 'C', 'D', 'E'],
            'col_3': [True, False, True, False, True],
            'col_4': pd.to_datetime([
                '2020-01-11', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'
            ]),
            'col_5': [4.0, 2.0, 3.0, 4.0, 5.0]
        })
        metric = TableFormat()
        compute_breakdown_mock.return_value = {'score': 0.6}

        # Run
        result = metric.compute(real_data, synthetic_data)

        # Assert
        compute_breakdown_mock.assert_called_once_with(real_data, synthetic_data, None)
        assert result == 0.6

    @patch('sdmetrics.single_table.table_format.SingleTableMetric.normalize')
    def test_normalize(self, normalize_mock):
        """Test the ``normalize`` method."""
        # Setup
        metric = TableFormat()
        raw_score = 0.9

        # Run
        result = metric.normalize(raw_score)

        # Assert
        normalize_mock.assert_called_once_with(raw_score)
        assert result == normalize_mock.return_value
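
As a quick sanity check on the 0.512 expectation in test_compute_breakdown_multiple_error: each failure mode knocks out one of five columns, so all three proportions are 0.8 and the score is their product. This is a worked restatement of the test, not new behavior:

proportion_correct_columns = 1 - 1 / 5  # 'col_2' missing from the synthetic data
proportion_valid_names = 1 - 1 / 5      # 'col_6' does not exist in the real data
proportion_valid_sdtypes = 1 - 1 / 5    # 'col_4' holds strings, not datetimes
score = proportion_correct_columns * proportion_valid_names * proportion_valid_sdtypes
print(score)  # 0.5120000000000001 -- the float artifact the test asserts verbatim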
