From f150da4401f7f2a3eaa684a234036ef121ed7ae5 Mon Sep 17 00:00:00 2001
From: R-Palazzo
Date: Tue, 21 Nov 2023 17:09:34 -0600
Subject: [PATCH] move table_name to generate

---
Review note: the duplicated ``real_data``/``synthetic_data`` fixture blocks in
``test_generate`` were dropped and an assertion on ``report.table_names`` was
added, since populating that attribute is the purpose of this change. Hunk
counts and the diffstat were adjusted; the post-image blob id on the test
file's ``index`` line predates this edit.

 .../multi_table/base_multi_table_report.py    | 13 ++++---
 .../test_base_multi_table_report.py           | 35 +++++++++++++++++++
 2 files changed, 44 insertions(+), 4 deletions(-)

diff --git a/sdmetrics/reports/multi_table/base_multi_table_report.py b/sdmetrics/reports/multi_table/base_multi_table_report.py
index 407e2298..3cfb3476 100644
--- a/sdmetrics/reports/multi_table/base_multi_table_report.py
+++ b/sdmetrics/reports/multi_table/base_multi_table_report.py
@@ -67,8 +67,11 @@ def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
 
         self._validate_relationships(real_data, synthetic_data, metadata)
 
-    def _validate(self, real_data, synthetic_data, metadata):
-        """Validate the inputs.
+    def generate(self, real_data, synthetic_data, metadata, verbose=True):
+        """Generate report.
+
+        This method generates the report by iterating through each property and calculating
+        the score for each property.
 
         Args:
             real_data (pandas.DataFrame):
@@ -76,10 +79,12 @@ def _validate(self, real_data, synthetic_data, metadata):
             synthetic_data (pandas.DataFrame):
                 The synthetic data.
             metadata (dict):
-                The metadata of the table.
+                The metadata, which contains each column's data type as well as relationships.
+            verbose (bool):
+                Whether or not to print report summary and progress.
         """
         self.table_names = list(metadata['tables'].keys())
-        super()._validate(real_data, synthetic_data, metadata)
+        return super().generate(real_data, synthetic_data, metadata, verbose)
 
     def _check_table_names(self, table_name):
         if table_name not in self.table_names:
diff --git a/tests/unit/reports/multi_table/test_base_multi_table_report.py b/tests/unit/reports/multi_table/test_base_multi_table_report.py
index 8e76dc72..4f0f35d2 100644
--- a/tests/unit/reports/multi_table/test_base_multi_table_report.py
+++ b/tests/unit/reports/multi_table/test_base_multi_table_report.py
@@ -144,6 +144,41 @@ def test__validate_metadata_matches_data(self, mock__validate_metadata_matches_d
         mock__validate_metadata_matches_data.assert_has_calls(expected_calls)
         report._validate_relationships.assert_called_once_with(real_data, synthetic_data, metadata)
 
+    @patch('sdmetrics.reports.base_report.BaseReport.generate')
+    def test_generate(self, mock_generate):
+        """Test the ``generate`` method."""
+        # Setup
+        real_data = {
+            'Table_1': pd.DataFrame({'col1': [1, 2, 3]}),
+            'Table_2': pd.DataFrame({'col2': [4, 5, 6]}),
+        }
+        synthetic_data = {
+            'Table_1': pd.DataFrame({'col1': [1, 2, 3]}),
+            'Table_2': pd.DataFrame({'col2': [4, 5, 6]}),
+        }
+        metadata = {
+            'tables': {
+                'Table_1': {
+                    'columns': {
+                        'col1': {},
+                    },
+                },
+                'Table_2': {
+                    'columns': {
+                        'col2': {}
+                    },
+                },
+            },
+        }
+        report = BaseMultiTableReport()
+
+        # Run
+        report.generate(real_data, synthetic_data, metadata)
+
+        # Assert
+        mock_generate.assert_called_once_with(real_data, synthetic_data, metadata, True)
+        assert report.table_names == ['Table_1', 'Table_2']
+
     def test__check_table_names(self):
         """Test the ``_check_table_names`` method."""
         # Setup