diff --git a/sdmetrics/reports/base_report.py b/sdmetrics/reports/base_report.py
index fd408999..bf72906f 100644
--- a/sdmetrics/reports/base_report.py
+++ b/sdmetrics/reports/base_report.py
@@ -82,9 +82,6 @@ def _validate(self, real_data, synthetic_data, metadata):
         self._validate_data_format(real_data, synthetic_data)
         self._validate_metadata_matches_data(real_data, synthetic_data, metadata)
 
-    def _handle_results(self, verbose):
-        raise NotImplementedError
-
     @staticmethod
     def convert_datetimes(real_data, synthetic_data, metadata):
         """Try to convert all datetime columns to datetime dtype.
diff --git a/sdmetrics/reports/multi_table/base_multi_table_report.py b/sdmetrics/reports/multi_table/base_multi_table_report.py
index 936d70be..3cfb3476 100644
--- a/sdmetrics/reports/multi_table/base_multi_table_report.py
+++ b/sdmetrics/reports/multi_table/base_multi_table_report.py
@@ -60,7 +60,6 @@ def _validate_relationships(self, real_data, synthetic_data, metadata):
 
     def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
         """Validate that the metadata matches the data."""
-        self.table_names = list(metadata['tables'].keys())
         for table in self.table_names:
             super()._validate_metadata_matches_data(
                 real_data[table], synthetic_data[table], metadata['tables'][table]
@@ -68,6 +67,25 @@ def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
 
         self._validate_relationships(real_data, synthetic_data, metadata)
 
+    def generate(self, real_data, synthetic_data, metadata, verbose=True):
+        """Generate report.
+
+        This method generates the report by iterating through each property and calculating
+        the score for each property.
+
+        Args:
+            real_data (dict[str, pandas.DataFrame]):
+                The real data, keyed by table name.
+            synthetic_data (dict[str, pandas.DataFrame]):
+                The synthetic data, keyed by table name.
+            metadata (dict):
+                The metadata, which contains each column's data type as well as relationships.
+            verbose (bool):
+                Whether or not to print report summary and progress.
+        """
+        self.table_names = list(metadata['tables'].keys())
+        return super().generate(real_data, synthetic_data, metadata, verbose)
+
     def _check_table_names(self, table_name):
         if table_name not in self.table_names:
             raise ValueError(f"Unknown table ('{table_name}'). Must be one of {self.table_names}.")
diff --git a/sdmetrics/reports/multi_table/diagnostic_report.py b/sdmetrics/reports/multi_table/diagnostic_report.py
index 6ccaa9bf..9428f9d8 100644
--- a/sdmetrics/reports/multi_table/diagnostic_report.py
+++ b/sdmetrics/reports/multi_table/diagnostic_report.py
@@ -17,3 +17,7 @@ def __init__(self):
             'Data Structure': Structure(),
             'Relationship Validity': RelationshipValidity()
         }
+
+    def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
+        """Validate the relationships only, skipping the per-table checks."""
+        self._validate_relationships(real_data, synthetic_data, metadata)
diff --git a/sdmetrics/reports/single_table/diagnostic_report.py b/sdmetrics/reports/single_table/diagnostic_report.py
index 4e36815e..1e96fc81 100644
--- a/sdmetrics/reports/single_table/diagnostic_report.py
+++ b/sdmetrics/reports/single_table/diagnostic_report.py
@@ -16,3 +16,6 @@ def __init__(self):
             'Data Validity': DataValidity(),
             'Data Structure': Structure(),
         }
+
+    def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
+        """Skip the metadata/data match check so mismatched inputs can be scored."""
diff --git a/tests/integration/reports/single_table/test_diagnostic_report.py b/tests/integration/reports/single_table/test_diagnostic_report.py
index 1ff5ec35..a4884593 100644
--- a/tests/integration/reports/single_table/test_diagnostic_report.py
+++ b/tests/integration/reports/single_table/test_diagnostic_report.py
@@ -201,3 +201,40 @@ def test_get_details_with_errors(self):
             report.get_details('Data Validity'),
             expected_details
         )
+
+    def test_report_runs_with_mismatch_data_metadata(self):
+        """Test that the report runs with mismatched data and metadata."""
+        # Setup
+        data = pd.DataFrame({
+            'id': [0, 1, 2],
+            'val1': ['a', 'a', 'b'],
+            'val2': [0.1, 2.4, 5.7]
+        })
+        synthetic_data = pd.DataFrame({
+            'id': [1, 2, 3],
+            'extra_col': ['x', 'y', 'z'],
+            'val1': ['c', 'd', 'd']
+        })
+
+        metadata = {
+            'columns': {
+                'id': {'sdtype': 'id'},
+                'val1': {'sdtype': 'categorical'},
+                'val2': {'sdtype': 'numerical'}
+            },
+            'primary_key': 'id'
+        }
+        report = DiagnosticReport()
+
+        # Run
+        report.generate(data, synthetic_data, metadata)
+
+        # Assert
+        expected_properties = pd.DataFrame({
+            'Property': ['Data Validity', 'Data Structure'],
+            'Score': [0.5, 0.5]
+        })
+        assert report.get_score() == 0.5
+        pd.testing.assert_frame_equal(
+            report.get_properties(), expected_properties
+        )
diff --git a/tests/unit/reports/multi_table/test_base_multi_table_report.py b/tests/unit/reports/multi_table/test_base_multi_table_report.py
index 156ac79e..412e1a68 100644
--- a/tests/unit/reports/multi_table/test_base_multi_table_report.py
+++ b/tests/unit/reports/multi_table/test_base_multi_table_report.py
@@ -131,6 +131,7 @@ def test__validate_metadata_matches_data(self, mock__validate_metadata_matches_d
 
         report = BaseMultiTableReport()
         report._validate_relationships = mock__validate_relationships
+        report.table_names = ['Table_1', 'Table_2']
 
         # Run
         report._validate_metadata_matches_data(real_data, synthetic_data, metadata)
@@ -143,6 +144,41 @@ def test__validate_metadata_matches_data(self, mock__validate_metadata_matches_d
         mock__validate_metadata_matches_data.assert_has_calls(expected_calls)
         report._validate_relationships.assert_called_once_with(real_data, synthetic_data, metadata)
 
+    @patch('sdmetrics.reports.base_report.BaseReport.generate')
+    def test_generate(self, mock_generate):
+        """Test the ``generate`` method."""
+        # Setup
+        real_data = {
+            'Table_1': pd.DataFrame({'col1': [1, 2, 3]}),
+            'Table_2': pd.DataFrame({'col2': [4, 5, 6]}),
+        }
+        synthetic_data = {
+            'Table_1': pd.DataFrame({'col1': [1, 2, 3]}),
+            'Table_2': pd.DataFrame({'col2': [4, 5, 6]}),
+        }
+        metadata = {
+            'tables': {
+                'Table_1': {
+                    'columns': {
+                        'col1': {},
+                    },
+                },
+                'Table_2': {
+                    'columns': {
+                        'col2': {}
+                    },
+                },
+            },
+        }
+        report = BaseMultiTableReport()
+
+        # Run
+        report.generate(real_data, synthetic_data, metadata)
+
+        # Assert
+        assert report.table_names == ['Table_1', 'Table_2']
+        mock_generate.assert_called_once_with(real_data, synthetic_data, metadata, True)
+
     def test__check_table_names(self):
         """Test the ``_check_table_names`` method."""
         # Setup
diff --git a/tests/unit/reports/test_base_report.py b/tests/unit/reports/test_base_report.py
index 8e56d474..09b16ea8 100644
--- a/tests/unit/reports/test_base_report.py
+++ b/tests/unit/reports/test_base_report.py
@@ -137,6 +137,36 @@ def test__validate(self):
             real_data, synthetic_data, metadata
         )
 
+    def test__validate_with_value_error(self):
+        """Test the ``_validate`` method with a ValueError."""
+        # Setup
+        base_report = BaseReport()
+        mock__validate_metadata_matches_data = Mock(
+            side_effect=ValueError('error message')
+        )
+        base_report._validate_metadata_matches_data = mock__validate_metadata_matches_data
+
+        real_data = pd.DataFrame({
+            'column1': [1, 2, 3],
+            'column2': ['a', 'b', 'c'],
+            'column3': [4, 5, 6]
+        })
+        synthetic_data = pd.DataFrame({
+            'column1': [1, 2, 3],
+            'column2': ['a', 'b', 'c'],
+            'column4': [4, 5, 6]
+        })
+        metadata = {
+            'columns': {
+                'column1': {'sdtype': 'numerical'},
+                'column2': {'sdtype': 'categorical'},
+            }
+        }
+
+        # Run and Assert
+        with pytest.raises(ValueError, match='error message'):
+            base_report._validate(real_data, synthetic_data, metadata)
+
     def test_convert_datetimes(self):
         """Test that ``_convert_datetimes`` tries to convert datetime columns."""
         # Setup