diff --git a/sdmetrics/reports/base_report.py b/sdmetrics/reports/base_report.py index bf72906f..c934f14f 100644 --- a/sdmetrics/reports/base_report.py +++ b/sdmetrics/reports/base_report.py @@ -61,13 +61,24 @@ def _validate_data_format(self, real_data, synthetic_data): return error_message = ( - f'Single table report {self.__class__.__name__} expects real and synthetic data to be' + f'Single table {self.__class__.__name__} expects real and synthetic data to be' ' pandas.DataFrame. If your real and synthetic data are dictionaries of tables, ' f'please use the multi-table {self.__class__.__name__} instead.' ) raise ValueError(error_message) + def _validate_metadata_format(self, metadata): + """Validate the metadata.""" + if not isinstance(metadata, dict): + raise TypeError('The provided metadata is not a dictionary.') + + if 'columns' not in metadata: + raise ValueError( + 'Single table reports expect metadata to contain a "columns" key with a mapping' + ' from column names to column informations.' + ) + def _validate(self, real_data, synthetic_data, metadata): """Validate the inputs. @@ -80,6 +91,7 @@ def _validate(self, real_data, synthetic_data, metadata): The metadata of the table. """ self._validate_data_format(real_data, synthetic_data) + self._validate_metadata_format(metadata) self._validate_metadata_matches_data(real_data, synthetic_data, metadata) @staticmethod diff --git a/sdmetrics/reports/multi_table/base_multi_table_report.py b/sdmetrics/reports/multi_table/base_multi_table_report.py index 3cfb3476..e0163318 100644 --- a/sdmetrics/reports/multi_table/base_multi_table_report.py +++ b/sdmetrics/reports/multi_table/base_multi_table_report.py @@ -32,13 +32,29 @@ def _validate_data_format(self, real_data, synthetic_data): return error_message = ( - f'Multi table report {self.__class__.__name__} expects real and synthetic data to be' + f'Multi table {self.__class__.__name__} expects real and synthetic data to be' ' dictionaries of pandas.DataFrame. If your real and synthetic data are pd.DataFrame,' f' please use the single-table {self.__class__.__name__} instead.' ) raise ValueError(error_message) + def _validate_metadata_format(self, metadata): + """Validate the metadata.""" + if not isinstance(metadata, dict): + raise TypeError('The provided metadata is not a dictionary.') + + if 'tables' not in metadata: + raise ValueError( + 'Multi table reports expect metadata to contain a "tables" key with a mapping' + ' from table names to metadata for each table.' + ) + for table_name, table_metadata in metadata['tables'].items(): + if 'columns' not in table_metadata: + raise ValueError( + f'The metadata for table "{table_name}" is missing a "columns" key.' + ) + def _validate_relationships(self, real_data, synthetic_data, metadata): """Validate that the relationships are valid.""" for rel in metadata.get('relationships', []): @@ -83,7 +99,7 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True): verbose (bool): Whether or not to print report summary and progress. """ - self.table_names = list(metadata['tables'].keys()) + self.table_names = list(metadata.get('tables', {}).keys()) return super().generate(real_data, synthetic_data, metadata, verbose) def _check_table_names(self, table_name): diff --git a/tests/unit/reports/multi_table/test_base_multi_table_report.py b/tests/unit/reports/multi_table/test_base_multi_table_report.py index 412e1a68..ff8061ac 100644 --- a/tests/unit/reports/multi_table/test_base_multi_table_report.py +++ b/tests/unit/reports/multi_table/test_base_multi_table_report.py @@ -35,13 +35,64 @@ def test__validate_data_format(self): # Run and Assert expected_message = ( - 'Multi table report BaseMultiTableReport expects real and synthetic data to be ' + 'Multi table BaseMultiTableReport expects real and synthetic data to be ' 'dictionaries of pandas.DataFrame. If your real and synthetic data are ' 'pd.DataFrame, please use the single-table BaseMultiTableReport instead.' ) with pytest.raises(ValueError, match=expected_message): base_report._validate_data_format(real_data, synthetic_data) + def test__validate_metadata_format(self): + """Test the ``_validate_metadata_format`` method. + + This test checks that the method raises an error when the metadata is not a dictionnary. + """ + # Setup + base_report = BaseMultiTableReport() + metadata = [] + + # Run and Assert + expected_message = 'The provided metadata is not a dictionary.' + with pytest.raises(TypeError, match=expected_message): + base_report._validate_metadata_format(metadata) + + def test__validate_metadata_format_with_no_tables(self): + """Test the ``_validate_metadata_format`` method. + + This test checks that the method raises an error when the metadata does not contain a + 'tables' key. + """ + # Setup + base_report = BaseMultiTableReport() + metadata = {} + + # Run and Assert + expected_message = ( + 'Multi table reports expect metadata to contain a "tables" key with a mapping from ' + 'table names to metadata for each table.' + ) + with pytest.raises(ValueError, match=expected_message): + base_report._validate_metadata_format(metadata) + + def test__validate_metadata_format_with_no_columns(self): + """Test the ``_validate_metadata_format`` method. + + This test checks that the method raises an error when the metadata does not contain a + 'columns' key. + """ + # Setup + base_report = BaseMultiTableReport() + metadata = { + 'tables': { + 'Table_1': {} + } + } + + # Run and Assert + expected_message = 'The metadata for table "Table_1" is missing a "columns" key.' + with pytest.raises(ValueError, match=expected_message): + base_report._validate_metadata_format(metadata) + def test__validate_relationships(self): """Test the ``_validate_relationships`` method.""" # Setup diff --git a/tests/unit/reports/test_base_report.py b/tests/unit/reports/test_base_report.py index 09b16ea8..776d8250 100644 --- a/tests/unit/reports/test_base_report.py +++ b/tests/unit/reports/test_base_report.py @@ -24,13 +24,47 @@ def test__validate_data_format(self): # Run and Assert expected_message = ( - 'Single table report BaseReport expects real and synthetic data to be ' + 'Single table BaseReport expects real and synthetic data to be ' 'pandas.DataFrame. If your real and synthetic data are dictionaries of ' 'tables, please use the multi-table BaseReport instead.' ) with pytest.raises(ValueError, match=expected_message): base_report._validate_data_format(real_data, synthetic_data) + def test__validate_metadata_format(self): + """Test the ``_validate_metadata_format`` method. + + This test checks that the method raises an error when the metadata is not a dictionary. + """ + # Setup + base_report = BaseReport() + metadata = 'metadata' + + # Run and Assert + expected_message = ( + 'The provided metadata is not a dictionary.' + ) + with pytest.raises(TypeError, match=expected_message): + base_report._validate_metadata_format(metadata) + + def test__validate_metadata_format_no_columns(self): + """Test the ``_validate_metadata_format`` method. + + This test checks that the method raises an error when the metadata does not contain a + 'columns' key. + """ + # Setup + base_report = BaseReport() + metadata = {} + + # Run and Assert + expected_message = ( + 'Single table reports expect metadata to contain a "columns" key with a mapping' + ' from column names to column informations.' + ) + with pytest.raises(ValueError, match=expected_message): + base_report._validate_metadata_format(metadata) + def test__validate_metadata_matches_data(self): """Test the ``_validate_metadata_matches_data`` method.