Skip to content

Commit

Permalink
ValueError in DiagnosticReport if synthetic data does not match met…
Browse files Browse the repository at this point in the history
…adata (#524)
  • Loading branch information
R-Palazzo committed Nov 27, 2023
1 parent 94cf38c commit e611cd9
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 4 deletions.
3 changes: 0 additions & 3 deletions sdmetrics/reports/base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ def _validate(self, real_data, synthetic_data, metadata):
self._validate_data_format(real_data, synthetic_data)
self._validate_metadata_matches_data(real_data, synthetic_data, metadata)

def _handle_results(self, verbose):
raise NotImplementedError

@staticmethod
def convert_datetimes(real_data, synthetic_data, metadata):
"""Try to convert all datetime columns to datetime dtype.
Expand Down
20 changes: 19 additions & 1 deletion sdmetrics/reports/multi_table/base_multi_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,32 @@ def _validate_relationships(self, real_data, synthetic_data, metadata):

def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
"""Validate that the metadata matches the data."""
self.table_names = list(metadata['tables'].keys())
for table in self.table_names:
super()._validate_metadata_matches_data(
real_data[table], synthetic_data[table], metadata['tables'][table]
)

self._validate_relationships(real_data, synthetic_data, metadata)

def generate(self, real_data, synthetic_data, metadata, verbose=True):
"""Generate report.
This method generates the report by iterating through each property and calculating
the score for each property.
Args:
real_data (pandas.DataFrame):
The real data.
synthetic_data (pandas.DataFrame):
The synthetic data.
metadata (dict):
The metadata, which contains each column's data type as well as relationships.
verbose (bool):
Whether or not to print report summary and progress.
"""
self.table_names = list(metadata['tables'].keys())
return super().generate(real_data, synthetic_data, metadata, verbose)

def _check_table_names(self, table_name):
if table_name not in self.table_names:
raise ValueError(f"Unknown table ('{table_name}'). Must be one of {self.table_names}.")
Expand Down
3 changes: 3 additions & 0 deletions sdmetrics/reports/multi_table/diagnostic_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ def __init__(self):
'Data Structure': Structure(),
'Relationship Validity': RelationshipValidity()
}

def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
self._validate_relationships(real_data, synthetic_data, metadata)
3 changes: 3 additions & 0 deletions sdmetrics/reports/single_table/diagnostic_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ def __init__(self):
'Data Validity': DataValidity(),
'Data Structure': Structure(),
}

def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
return
37 changes: 37 additions & 0 deletions tests/integration/reports/single_table/test_diagnostic_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,40 @@ def test_get_details_with_errors(self):
report.get_details('Data Validity'),
expected_details
)

def test_report_runs_with_mismatch_data_metadata(self):
"""Test that the report runs with mismatched data and metadata."""
# Setup
data = pd.DataFrame({
'id': [0, 1, 2],
'val1': ['a', 'a', 'b'],
'val2': [0.1, 2.4, 5.7]
})
synthetic_data = pd.DataFrame({
'id': [1, 2, 3],
'extra_col': ['x', 'y', 'z'],
'val1': ['c', 'd', 'd']
})

metadata = {
'columns': {
'id': {'sdtype': 'id'},
'val1': {'sdtype': 'categorical'},
'val2': {'sdtype': 'numerical'}
},
'primary_key': 'id'
}
report = DiagnosticReport()

# Run
report.generate(data, synthetic_data, metadata)

# Assert
expected_properties = pd.DataFrame({
'Property': ['Data Validity', 'Data Structure'],
'Score': [0.5, 0.5]
})
assert report.get_score() == 0.5
pd.testing.assert_frame_equal(
report.get_properties(), expected_properties
)
44 changes: 44 additions & 0 deletions tests/unit/reports/multi_table/test_base_multi_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def test__validate_metadata_matches_data(self, mock__validate_metadata_matches_d

report = BaseMultiTableReport()
report._validate_relationships = mock__validate_relationships
report.table_names = ['Table_1', 'Table_2']

# Run
report._validate_metadata_matches_data(real_data, synthetic_data, metadata)
Expand All @@ -143,6 +144,49 @@ def test__validate_metadata_matches_data(self, mock__validate_metadata_matches_d
mock__validate_metadata_matches_data.assert_has_calls(expected_calls)
report._validate_relationships.assert_called_once_with(real_data, synthetic_data, metadata)

@patch('sdmetrics.reports.base_report.BaseReport.generate')
def test_generate(self, mock_generate):
"""Test the ``generate`` method."""
# Setup
real_data = {
'Table_1': pd.DataFrame({'col1': [1, 2, 3]}),
'Table_2': pd.DataFrame({'col2': [4, 5, 6]}),
}
synthetic_data = {
'Table_1': pd.DataFrame({'col1': [1, 2, 3]}),
'Table_2': pd.DataFrame({'col2': [4, 5, 6]}),
}
real_data = {
'Table_1': pd.DataFrame({'col1': [1, 2, 3]}),
'Table_2': pd.DataFrame({'col2': [4, 5, 6]}),
}
synthetic_data = {
'Table_1': pd.DataFrame({'col1': [1, 2, 3]}),
'Table_2': pd.DataFrame({'col2': [4, 5, 6]}),
}
metadata = {
'tables': {
'Table_1': {
'columns': {
'col1': {},
},
},
'Table_2': {
'columns': {
'col2': {}
},
},
},
}
report = BaseMultiTableReport()

# Run
report.generate(real_data, synthetic_data, metadata)

# Assert
assert report.table_names == ['Table_1', 'Table_2']
mock_generate.assert_called_once_with(real_data, synthetic_data, metadata, True)

def test__check_table_names(self):
"""Test the ``_check_table_names`` method."""
# Setup
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/reports/test_base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,36 @@ def test__validate(self):
real_data, synthetic_data, metadata
)

def test__validate_with_value_error(self):
"""Test the ``_validate`` method with a ValueError."""
# Setup
base_report = BaseReport()
mock__validate_metadata_matches_data = Mock(
side_effect=ValueError('error message')
)
base_report._validate_metadata_matches_data = mock__validate_metadata_matches_data

real_data = pd.DataFrame({
'column1': [1, 2, 3],
'column2': ['a', 'b', 'c'],
'column3': [4, 5, 6]
})
synthetic_data = pd.DataFrame({
'column1': [1, 2, 3],
'column2': ['a', 'b', 'c'],
'column4': [4, 5, 6]
})
metadata = {
'columns': {
'column1': {'sdtype': 'numerical'},
'column2': {'sdtype': 'categorical'},
}
}

# Run and Assert
with pytest.raises(ValueError, match='error message'):
base_report._validate(real_data, synthetic_data, metadata)

def test_convert_datetimes(self):
"""Test that ``_convert_datetimes`` tries to convert datetime columns."""
# Setup
Expand Down

0 comments on commit e611cd9

Please sign in to comment.