Skip to content

Commit

Permalink
Add metadata validation (#527)
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo authored Nov 22, 2023
1 parent 61676fc commit 9e83c54
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 5 deletions.
14 changes: 13 additions & 1 deletion sdmetrics/reports/base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,24 @@ def _validate_data_format(self, real_data, synthetic_data):
return

error_message = (
f'Single table report {self.__class__.__name__} expects real and synthetic data to be'
f'Single table {self.__class__.__name__} expects real and synthetic data to be'
' pandas.DataFrame. If your real and synthetic data are dictionaries of tables, '
f'please use the multi-table {self.__class__.__name__} instead.'

)
raise ValueError(error_message)

def _validate_metadata_format(self, metadata):
"""Validate the metadata."""
if not isinstance(metadata, dict):
raise TypeError('The provided metadata is not a dictionary.')

if 'columns' not in metadata:
raise ValueError(
'Single table reports expect metadata to contain a "columns" key with a mapping'
' from column names to column informations.'
)

def _validate(self, real_data, synthetic_data, metadata):
"""Validate the inputs.
Expand All @@ -80,6 +91,7 @@ def _validate(self, real_data, synthetic_data, metadata):
The metadata of the table.
"""
self._validate_data_format(real_data, synthetic_data)
self._validate_metadata_format(metadata)
self._validate_metadata_matches_data(real_data, synthetic_data, metadata)

@staticmethod
Expand Down
20 changes: 18 additions & 2 deletions sdmetrics/reports/multi_table/base_multi_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,29 @@ def _validate_data_format(self, real_data, synthetic_data):
return

error_message = (
f'Multi table report {self.__class__.__name__} expects real and synthetic data to be'
f'Multi table {self.__class__.__name__} expects real and synthetic data to be'
' dictionaries of pandas.DataFrame. If your real and synthetic data are pd.DataFrame,'
f' please use the single-table {self.__class__.__name__} instead.'
)

raise ValueError(error_message)

def _validate_metadata_format(self, metadata):
"""Validate the metadata."""
if not isinstance(metadata, dict):
raise TypeError('The provided metadata is not a dictionary.')

if 'tables' not in metadata:
raise ValueError(
'Multi table reports expect metadata to contain a "tables" key with a mapping'
' from table names to metadata for each table.'
)
for table_name, table_metadata in metadata['tables'].items():
if 'columns' not in table_metadata:
raise ValueError(
f'The metadata for table "{table_name}" is missing a "columns" key.'
)

def _validate_relationships(self, real_data, synthetic_data, metadata):
"""Validate that the relationships are valid."""
for rel in metadata.get('relationships', []):
Expand Down Expand Up @@ -83,7 +99,7 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True):
verbose (bool):
Whether or not to print report summary and progress.
"""
self.table_names = list(metadata['tables'].keys())
self.table_names = list(metadata.get('tables', {}).keys())
return super().generate(real_data, synthetic_data, metadata, verbose)

def _check_table_names(self, table_name):
Expand Down
53 changes: 52 additions & 1 deletion tests/unit/reports/multi_table/test_base_multi_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,64 @@ def test__validate_data_format(self):

# Run and Assert
expected_message = (
'Multi table report BaseMultiTableReport expects real and synthetic data to be '
'Multi table BaseMultiTableReport expects real and synthetic data to be '
'dictionaries of pandas.DataFrame. If your real and synthetic data are '
'pd.DataFrame, please use the single-table BaseMultiTableReport instead.'
)
with pytest.raises(ValueError, match=expected_message):
base_report._validate_data_format(real_data, synthetic_data)

def test__validate_metadata_format(self):
"""Test the ``_validate_metadata_format`` method.
This test checks that the method raises an error when the metadata is not a dictionnary.
"""
# Setup
base_report = BaseMultiTableReport()
metadata = []

# Run and Assert
expected_message = 'The provided metadata is not a dictionary.'
with pytest.raises(TypeError, match=expected_message):
base_report._validate_metadata_format(metadata)

def test__validate_metadata_format_with_no_tables(self):
"""Test the ``_validate_metadata_format`` method.
This test checks that the method raises an error when the metadata does not contain a
'tables' key.
"""
# Setup
base_report = BaseMultiTableReport()
metadata = {}

# Run and Assert
expected_message = (
'Multi table reports expect metadata to contain a "tables" key with a mapping from '
'table names to metadata for each table.'
)
with pytest.raises(ValueError, match=expected_message):
base_report._validate_metadata_format(metadata)

def test__validate_metadata_format_with_no_columns(self):
"""Test the ``_validate_metadata_format`` method.
This test checks that the method raises an error when the metadata does not contain a
'columns' key.
"""
# Setup
base_report = BaseMultiTableReport()
metadata = {
'tables': {
'Table_1': {}
}
}

# Run and Assert
expected_message = 'The metadata for table "Table_1" is missing a "columns" key.'
with pytest.raises(ValueError, match=expected_message):
base_report._validate_metadata_format(metadata)

def test__validate_relationships(self):
"""Test the ``_validate_relationships`` method."""
# Setup
Expand Down
36 changes: 35 additions & 1 deletion tests/unit/reports/test_base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,47 @@ def test__validate_data_format(self):

# Run and Assert
expected_message = (
'Single table report BaseReport expects real and synthetic data to be '
'Single table BaseReport expects real and synthetic data to be '
'pandas.DataFrame. If your real and synthetic data are dictionaries of '
'tables, please use the multi-table BaseReport instead.'
)
with pytest.raises(ValueError, match=expected_message):
base_report._validate_data_format(real_data, synthetic_data)

def test__validate_metadata_format(self):
"""Test the ``_validate_metadata_format`` method.
This test checks that the method raises an error when the metadata is not a dictionary.
"""
# Setup
base_report = BaseReport()
metadata = 'metadata'

# Run and Assert
expected_message = (
'The provided metadata is not a dictionary.'
)
with pytest.raises(TypeError, match=expected_message):
base_report._validate_metadata_format(metadata)

def test__validate_metadata_format_no_columns(self):
"""Test the ``_validate_metadata_format`` method.
This test checks that the method raises an error when the metadata does not contain a
'columns' key.
"""
# Setup
base_report = BaseReport()
metadata = {}

# Run and Assert
expected_message = (
'Single table reports expect metadata to contain a "columns" key with a mapping'
' from column names to column informations.'
)
with pytest.raises(ValueError, match=expected_message):
base_report._validate_metadata_format(metadata)

def test__validate_metadata_matches_data(self):
"""Test the ``_validate_metadata_matches_data`` method.
Expand Down

0 comments on commit 9e83c54

Please sign in to comment.