Showing 10 changed files with 495 additions and 64 deletions.
18 changes: 18 additions & 0 deletions
sdmetrics/reports/multi_table/_properties/data_validity.py
@@ -0,0 +1,18 @@
"""Data validity property for multi-table."""
from sdmetrics.reports.multi_table._properties import BaseMultiTableProperty
from sdmetrics.reports.single_table._properties import DataValidity as SingleTableDataValidity


class DataValidity(BaseMultiTableProperty):
    """Data Validity property class for multi-table.

    This property computes, at base, whether each column contains valid data.
    The metric is based on the data type of each column.
    A metric score is computed column-wise and the final score is the average over all columns.
    The BoundaryAdherence metric is used for numerical and datetime columns, the
    CategoryAdherence metric is used for categorical and boolean columns, and the
    KeyUniqueness metric for primary and alternate keys. The other column types are
    ignored by this property.
    """

    _single_table_property = SingleTableDataValidity
    _num_iteration_case = 'column'
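
A minimal usage sketch for this multi-table property (not part of the diff), mirroring how the integration test added later in this commit drives it through `load_demo` and `get_score`:

```python
from sdmetrics.demos import load_demo
from sdmetrics.reports.multi_table._properties import DataValidity

# Load the multi-table demo dataset that ships with SDMetrics.
real_data, synthetic_data, metadata = load_demo(modality='multi_table')

# The score is the average data validity across every column of every table.
data_validity = DataValidity()
score = data_validity.get_score(real_data, synthetic_data, metadata)
print(score)  # a value between 0 and 1
```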
126 changes: 126 additions & 0 deletions
sdmetrics/reports/single_table/_properties/data_validity.py
@@ -0,0 +1,126 @@
import numpy as np
import pandas as pd
import plotly.express as px

from sdmetrics.reports.single_table._properties import BaseSingleTableProperty
from sdmetrics.reports.utils import PlotConfig
from sdmetrics.single_column import BoundaryAdherence, CategoryAdherence, KeyUniqueness


class DataValidity(BaseSingleTableProperty):
    """Data Validity property class for single table.

    This property computes, at base, whether each column contains valid data.
    The metric is based on the data type of each column.
    The BoundaryAdherence metric is used for numerical and datetime columns, the
    CategoryAdherence metric is used for categorical and boolean columns, and the
    KeyUniqueness metric for primary and alternate keys. The other column types are
    ignored by this property.
    """

    _num_iteration_case = 'column'
    _sdtype_to_metric = {
        'numerical': BoundaryAdherence,
        'datetime': BoundaryAdherence,
        'categorical': CategoryAdherence,
        'boolean': CategoryAdherence,
        'id': KeyUniqueness,
    }

    def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=None):
        """Generate the _details dataframe for the data validity property.

        Args:
            real_data (pandas.DataFrame):
                The real data.
            synthetic_data (pandas.DataFrame):
                The synthetic data.
            metadata (dict):
                The metadata of the table.
            progress_bar (tqdm.tqdm or None):
                The progress bar to use. Defaults to None.
        """
        column_names, metric_names, scores = [], [], []
        error_messages = []
        primary_key = metadata.get('primary_key')
        alternate_keys = metadata.get('alternate_keys', [])
        for column_name in metadata['columns']:
            sdtype = metadata['columns'][column_name]['sdtype']
            primary_key_match = column_name == primary_key
            alternate_key_match = column_name in alternate_keys
            is_unique = primary_key_match or alternate_key_match

            try:
                if sdtype not in self._sdtype_to_metric and not is_unique:
                    continue

                metric = self._sdtype_to_metric.get(sdtype, KeyUniqueness)
                column_score = metric.compute(
                    real_data[column_name], synthetic_data[column_name]
                )
                error_message = None

            except Exception as e:
                column_score = np.nan
                error_message = f'{type(e).__name__}: {e}'
            finally:
                if progress_bar:
                    progress_bar.update()

            column_names.append(column_name)
            metric_names.append(metric.__name__)
            scores.append(column_score)
            error_messages.append(error_message)

        result = pd.DataFrame({
            'Column': column_names,
            'Metric': metric_names,
            'Score': scores,
            'Error': error_messages,
        })

        if result['Error'].isna().all():
            result = result.drop('Error', axis=1)

        return result

    def get_visualization(self):
        """Create a plot to show the data validity scores.

        Returns:
            plotly.graph_objects._figure.Figure
        """
        average_score = round(self._compute_average(), 2)

        fig = px.bar(
            data_frame=self.details,
            x='Column',
            y='Score',
            title=f'Data Diagnostic: Data Validity (Average Score={average_score})',
            category_orders={'group': list(self.details['Column'])},
            color='Metric',
            color_discrete_map={
                'BoundaryAdherence': PlotConfig.DATACEBO_DARK,
                'CategoryAdherence': PlotConfig.DATACEBO_BLUE,
                'KeyUniqueness': PlotConfig.DATACEBO_GREEN,
            },
            pattern_shape='Metric',
            pattern_shape_sequence=['', '/', '.'],
            hover_name='Column',
            hover_data={
                'Column': False,
                'Metric': True,
                'Score': True,
            },
        )

        fig.update_yaxes(range=[0, 1])

        fig.update_layout(
            xaxis_categoryorder='total ascending',
            plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
            margin={'t': 150},
            font={'size': PlotConfig.FONT_SIZE},
        )

        return fig
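
As a usage sketch (not part of the diff): the property is driven through `get_score`, after which `details` holds the per-column breakdown and `get_visualization` returns the plotly figure built above; calling `.show()` on it is standard plotly behaviour, not an SDMetrics API.

```python
from sdmetrics.demos import load_demo
from sdmetrics.reports.single_table._properties import DataValidity

real_data, synthetic_data, metadata = load_demo('single_table')

data_validity = DataValidity()
score = data_validity.get_score(real_data, synthetic_data, metadata)

# One row per column: the metric used (BoundaryAdherence, CategoryAdherence
# or KeyUniqueness) and its score between 0 and 1.
print(data_validity.details)

# Bar chart of per-column scores, colored and patterned by metric.
data_validity.get_visualization().show()
```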
39 changes: 39 additions & 0 deletions
tests/integration/reports/multi_table/_properties/test_data_validity.py
@@ -0,0 +1,39 @@
from unittest.mock import Mock

from tqdm import tqdm

from sdmetrics.demos import load_demo
from sdmetrics.reports.multi_table._properties import DataValidity


class TestDataValidity:

    def test_end_to_end(self):
        """Test the ``DataValidity`` multi-table property end to end."""
        # Setup
        real_data, synthetic_data, metadata = load_demo(modality='multi_table')
        data_validity = DataValidity()

        # Run
        result = data_validity.get_score(real_data, synthetic_data, metadata)

        # Assert
        assert result == 0.9444444444444445

    def test_with_progress_bar(self):
        """Test that the progress bar is correctly updated."""
        # Setup
        real_data, synthetic_data, metadata = load_demo(modality='multi_table')
        data_validity = DataValidity()
        num_columns = sum(len(table['columns']) for table in metadata['tables'].values())

        progress_bar = tqdm(total=num_columns)
        mock_update = Mock()
        progress_bar.update = mock_update

        # Run
        result = data_validity.get_score(real_data, synthetic_data, metadata, progress_bar)

        # Assert
        assert result == 0.9444444444444445
        assert mock_update.call_count == num_columns
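
The test above swaps `tqdm.update` for a mock; with a live progress bar the behaviour is the same, since `_generate_details` calls `progress_bar.update()` once per column. A sketch under that assumption (the `desc` label is purely illustrative):

```python
from tqdm import tqdm

from sdmetrics.demos import load_demo
from sdmetrics.reports.multi_table._properties import DataValidity

real_data, synthetic_data, metadata = load_demo(modality='multi_table')

# One update per column, so the bar total is the overall column count.
num_columns = sum(len(table['columns']) for table in metadata['tables'].values())
with tqdm(total=num_columns, desc='Data Validity') as progress_bar:
    score = DataValidity().get_score(real_data, synthetic_data, metadata, progress_bar)
```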
73 changes: 73 additions & 0 deletions
tests/integration/reports/single_table/_properties/test_data_validity.py
@@ -0,0 +1,73 @@
import pandas as pd

from sdmetrics.demos import load_demo
from sdmetrics.reports.single_table._properties import DataValidity


class TestDataValidity:

    def test_get_score(self):
        """Test the ``get_score`` method."""
        # Setup
        real_data, synthetic_data, metadata = load_demo('single_table')

        # Run
        data_validity_property = DataValidity()
        score = data_validity_property.get_score(real_data, synthetic_data, metadata)

        # Assert
        expected_details_dict = {
            'Column': [
                'start_date', 'end_date', 'salary', 'duration', 'student_id',
                'high_perc', 'high_spec', 'mba_spec', 'second_perc', 'gender',
                'degree_perc', 'placed', 'experience_years', 'employability_perc',
                'mba_perc', 'work_experience', 'degree_type'
            ],
            'Metric': [
                'BoundaryAdherence', 'BoundaryAdherence', 'BoundaryAdherence', 'BoundaryAdherence',
                'KeyUniqueness', 'BoundaryAdherence', 'CategoryAdherence', 'CategoryAdherence',
                'BoundaryAdherence', 'CategoryAdherence', 'BoundaryAdherence', 'CategoryAdherence',
                'BoundaryAdherence', 'BoundaryAdherence', 'BoundaryAdherence', 'CategoryAdherence',
                'CategoryAdherence'
            ],
            'Score': [
                0.8503937007874016, 0.8615384615384616, 0.9444444444444444,
                1.0, 1.0, 0.8651162790697674, 1.0, 1.0, 0.9255813953488372,
                1.0, 0.9441860465116279, 1.0, 1.0, 0.8883720930232558,
                0.8930232558139535, 1.0, 1.0
            ]
        }
        expected_details = pd.DataFrame(expected_details_dict)
        pd.testing.assert_frame_equal(data_validity_property.details, expected_details)
        assert score == 0.9513326868551618

    def test_get_score_errors(self):
        """Test the ``get_score`` method when the metrics raise errors for some columns."""
        # Setup
        real_data, synthetic_data, metadata = load_demo('single_table')

        real_data['start_date'].iloc[0] = 0
        real_data['employability_perc'].iloc[2] = 'a'

        expected_message_1 = (
            "TypeError: '<=' not supported between instances of 'int' and 'Timestamp'"
        )
        expected_message_2 = (
            "TypeError: '<=' not supported between instances of 'float' and 'str'"
        )

        # Run
        data_validity_property = DataValidity()
        score = data_validity_property.get_score(real_data, synthetic_data, metadata)

        # Assert
        details = data_validity_property.details
        details_nan = details.loc[pd.isna(details['Score'])]
        column_names_nan = details_nan['Column'].tolist()
        error_messages = details_nan['Error'].tolist()
        assert column_names_nan == ['start_date', 'employability_perc']
        assert error_messages[0] == expected_message_1
        assert error_messages[1] == expected_message_2
        assert score == 0.9622593255151395
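
The same error handling can be seen on a tiny hand-built table. This sketch (not part of the diff) assumes that mixing a string into a numerical column makes BoundaryAdherence raise a `TypeError`, as in the test above, and that the property then records the message in the `Error` column while the remaining columns are still scored; the metadata dict is a hypothetical minimal example.

```python
import pandas as pd

from sdmetrics.reports.single_table._properties import DataValidity

# Hypothetical two-column table used only for illustration.
metadata = {
    'columns': {
        'age': {'sdtype': 'numerical'},
        'segment': {'sdtype': 'categorical'},
    },
}
real_data = pd.DataFrame({'age': [25, 30, 41], 'segment': ['a', 'b', 'a']})
synthetic_data = pd.DataFrame({'age': [27, 'oops', 38], 'segment': ['b', 'a', 'a']})

data_validity = DataValidity()
score = data_validity.get_score(real_data, synthetic_data, metadata)

# 'age' should show a NaN score plus the captured TypeError; 'segment' scores normally,
# and the overall score averages only the columns that scored successfully.
details = data_validity.details
print(details.loc[details['Score'].isna(), ['Column', 'Error']])
print(score)
```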
14 changes: 14 additions & 0 deletions
tests/unit/reports/multi_table/_properties/test_validity.py
@@ -0,0 +1,14 @@
"""Test the Data Validity multi-table class."""
from sdmetrics.reports.multi_table._properties import DataValidity
from sdmetrics.reports.single_table._properties import DataValidity as SingleTableDataValidity


def test__init__():
    """Test the ``__init__`` method."""
    # Setup
    data_validity = DataValidity()

    # Assert
    assert data_validity._properties == {}
    assert data_validity._single_table_property == SingleTableDataValidity
    assert data_validity._num_iteration_case == 'column'