From d0ab31b06add82a0b9d53059d1ced2cc3d22b6fb Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Tue, 31 Oct 2023 10:30:19 -0500 Subject: [PATCH 1/3] Make quality reports fault tolerant on missing relationships in metadata --- .../reports/multi_table/_properties/base.py | 22 +++-- .../_properties/inter_table_trends.py | 83 ++++++++++--------- .../multi_table/test_quality_report.py | 21 +++++ .../multi_table/_properties/test_base.py | 13 +-- 4 files changed, 83 insertions(+), 56 deletions(-) diff --git a/sdmetrics/reports/multi_table/_properties/base.py b/sdmetrics/reports/multi_table/_properties/base.py index f585a86e..29f3a971 100644 --- a/sdmetrics/reports/multi_table/_properties/base.py +++ b/sdmetrics/reports/multi_table/_properties/base.py @@ -1,4 +1,6 @@ """Multi table base property class.""" +import warnings + import numpy as np import pandas as pd @@ -33,16 +35,24 @@ def _get_num_iterations(self, metadata): elif self._num_iteration_case == 'table': return len(metadata['tables']) elif self._num_iteration_case == 'relationship': - return len(metadata['relationships']) + try: + return len(metadata['relationships']) + except KeyError as e: + message = f'{type(e).__name__}: {e}. No relationships found in the data.' + warnings.warn(message) + return 0 elif self._num_iteration_case == 'column_pair': num_columns = [len(table['columns']) for table in metadata['tables'].values()] return sum([(n_cols * (n_cols - 1)) // 2 for n_cols in num_columns]) elif self._num_iteration_case == 'inter_table_column_pair': iterations = 0 - for relationship in metadata['relationships']: - parent_columns = metadata['tables'][relationship['parent_table_name']]['columns'] - child_columns = metadata['tables'][relationship['child_table_name']]['columns'] - iterations += (len(parent_columns) * len(child_columns)) + if 'relationships' in metadata: + for relationship in metadata['relationships']: + parent_columns = \ + metadata['tables'][relationship['parent_table_name']]['columns'] + child_columns = \ + metadata['tables'][relationship['child_table_name']]['columns'] + iterations += (len(parent_columns) * len(child_columns)) return iterations def _compute_average(self): @@ -51,6 +61,8 @@ def _compute_average(self): has_score_column = 'Score' in self.details.columns assert_message = "The property details must be a DataFrame with a 'Score' column." + if not has_score_column: + return np.nan assert is_dataframe, assert_message assert has_score_column, assert_message diff --git a/sdmetrics/reports/multi_table/_properties/inter_table_trends.py b/sdmetrics/reports/multi_table/_properties/inter_table_trends.py index 4dcad9e1..20d7c839 100644 --- a/sdmetrics/reports/multi_table/_properties/inter_table_trends.py +++ b/sdmetrics/reports/multi_table/_properties/inter_table_trends.py @@ -106,48 +106,49 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No The progress bar object. Defaults to None. """ all_details = [] - for relationship in metadata['relationships']: - parent = relationship['parent_table_name'] - child = relationship['child_table_name'] - foreign_key = relationship['child_foreign_key'] - - denormalized_real, denormalized_synthetic = self._denormalize_tables( - real_data, - synthetic_data, - relationship - ) - - merged_metadata, parent_cols, child_cols = self._merge_metadata( - metadata, - parent, - child - ) - - parent_child_pairs = itertools.product(parent_cols, child_cols) - - self._properties[(parent, child, foreign_key)] = SingleTableColumnPairTrends() - details = self._properties[(parent, child, foreign_key)]._generate_details( - denormalized_real, denormalized_synthetic, merged_metadata, - progress_bar=progress_bar, column_pairs=parent_child_pairs - ) + if 'relationships' in metadata: + for relationship in metadata['relationships']: + parent = relationship['parent_table_name'] + child = relationship['child_table_name'] + foreign_key = relationship['child_foreign_key'] + + denormalized_real, denormalized_synthetic = self._denormalize_tables( + real_data, + synthetic_data, + relationship + ) + + merged_metadata, parent_cols, child_cols = self._merge_metadata( + metadata, + parent, + child + ) + + parent_child_pairs = itertools.product(parent_cols, child_cols) + + self._properties[(parent, child, foreign_key)] = SingleTableColumnPairTrends() + details = self._properties[(parent, child, foreign_key)]._generate_details( + denormalized_real, denormalized_synthetic, merged_metadata, + progress_bar=progress_bar, column_pairs=parent_child_pairs + ) + + details['Parent Table'] = parent + details['Child Table'] = child + details['Foreign Key'] = foreign_key + if not details.empty: + details['Column 1'] = details['Column 1'].str.replace(f'{parent}.', '', n=1) + details['Column 2'] = details['Column 2'].str.replace(f'{child}.', '', n=1) + all_details.append(details) + + self.details = pd.concat(all_details, axis=0).reset_index(drop=True) + detail_columns = [ + 'Parent Table', 'Child Table', 'Foreign Key', 'Column 1', 'Column 2', + 'Metric', 'Score', 'Real Correlation', 'Synthetic Correlation' + ] + if 'Error' in self.details.columns: + detail_columns.append('Error') - details['Parent Table'] = parent - details['Child Table'] = child - details['Foreign Key'] = foreign_key - if not details.empty: - details['Column 1'] = details['Column 1'].str.replace(f'{parent}.', '', n=1) - details['Column 2'] = details['Column 2'].str.replace(f'{child}.', '', n=1) - all_details.append(details) - - self.details = pd.concat(all_details, axis=0).reset_index(drop=True) - detail_columns = [ - 'Parent Table', 'Child Table', 'Foreign Key', 'Column 1', 'Column 2', - 'Metric', 'Score', 'Real Correlation', 'Synthetic Correlation' - ] - if 'Error' in self.details.columns: - detail_columns.append('Error') - - self.details = self.details[detail_columns] + self.details = self.details[detail_columns] def get_visualization(self, table_name=None): """Create a plot to show the inter table trends data. diff --git a/tests/integration/reports/multi_table/test_quality_report.py b/tests/integration/reports/multi_table/test_quality_report.py index 888dbd06..eec51cea 100644 --- a/tests/integration/reports/multi_table/test_quality_report.py +++ b/tests/integration/reports/multi_table/test_quality_report.py @@ -306,3 +306,24 @@ def test_quality_report_with_errors(): assert score == 0.7008862433862433 pd.testing.assert_frame_equal(properties, expected_properties) pd.testing.assert_frame_equal(details_column_shapes, expected_details) + + +def test_quality_report_with_no_relationships(): + # Setup + real_data, synthetic_data, metadata = load_demo(modality='multi_table') + + del metadata['relationships'] + report = QualityReport() + + # Run + report.generate(real_data, synthetic_data, metadata, verbose=True) + score = report.get_score() + + # Assert + expected_properties = pd.DataFrame({ + 'Property': ['Column Shapes', 'Column Pair Trends', 'Cardinality', 'Intertable Trends'], + 'Score': [0.792262, 0.424967, np.nan, np.nan] + }) + properties = report.get_properties() + pd.testing.assert_frame_equal(properties, expected_properties) + assert score == 0.6086142240422239 diff --git a/tests/unit/reports/multi_table/_properties/test_base.py b/tests/unit/reports/multi_table/_properties/test_base.py index 151345bd..1c75b02d 100644 --- a/tests/unit/reports/multi_table/_properties/test_base.py +++ b/tests/unit/reports/multi_table/_properties/test_base.py @@ -1,6 +1,5 @@ """Test BaseMultiTableProperty class.""" -import re from unittest.mock import Mock import numpy as np @@ -176,21 +175,15 @@ def test__generate_details_raises_error(self): with pytest.raises(NotImplementedError): base_property._generate_details(None, None, None, None) - def test__compute_average_raises_error(self): + def test__compute_average_sends_nan(self): """Test that the method raises an error when _details has not been computed.""" # Setup base_property = BaseMultiTableProperty() # Run and Assert - expected_error_message = re.escape( - "The property details must be a DataFrame with a 'Score' column." - ) - with pytest.raises(AssertionError, match=expected_error_message): - base_property._compute_average() - + assert np.isnan(base_property._compute_average()) base_property.details = pd.DataFrame({'Column': ['a', 'b', 'c']}) - with pytest.raises(AssertionError, match=expected_error_message): - base_property._compute_average() + assert np.isnan(base_property._compute_average()) def test_get_score(self): """Test the ``get_score`` method.""" From 6dc6cd1142dfde35a6a77dd8b209dfa592d19da3 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Tue, 31 Oct 2023 10:34:00 -0500 Subject: [PATCH 2/3] remove unneeded assert --- sdmetrics/reports/multi_table/_properties/base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sdmetrics/reports/multi_table/_properties/base.py b/sdmetrics/reports/multi_table/_properties/base.py index 29f3a971..6f0a2fcf 100644 --- a/sdmetrics/reports/multi_table/_properties/base.py +++ b/sdmetrics/reports/multi_table/_properties/base.py @@ -59,12 +59,11 @@ def _compute_average(self): """Average the scores for each column.""" is_dataframe = isinstance(self.details, pd.DataFrame) has_score_column = 'Score' in self.details.columns - assert_message = "The property details must be a DataFrame with a 'Score' column." + assert_message = "The property details must be in a DataFrame with a 'Score' column." + assert is_dataframe, assert_message if not has_score_column: return np.nan - assert is_dataframe, assert_message - assert has_score_column, assert_message return self.details['Score'].mean() From 49901a95e8319468aa6d33cfe1e3aa5c2fa2f7ad Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 1 Nov 2023 12:21:51 -0500 Subject: [PATCH 3/3] Use a one liner to check for relationship data --- .../reports/multi_table/_properties/base.py | 13 ++-- .../_properties/inter_table_trends.py | 66 +++++++++---------- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/sdmetrics/reports/multi_table/_properties/base.py b/sdmetrics/reports/multi_table/_properties/base.py index 6f0a2fcf..d885f0ea 100644 --- a/sdmetrics/reports/multi_table/_properties/base.py +++ b/sdmetrics/reports/multi_table/_properties/base.py @@ -46,13 +46,12 @@ def _get_num_iterations(self, metadata): return sum([(n_cols * (n_cols - 1)) // 2 for n_cols in num_columns]) elif self._num_iteration_case == 'inter_table_column_pair': iterations = 0 - if 'relationships' in metadata: - for relationship in metadata['relationships']: - parent_columns = \ - metadata['tables'][relationship['parent_table_name']]['columns'] - child_columns = \ - metadata['tables'][relationship['child_table_name']]['columns'] - iterations += (len(parent_columns) * len(child_columns)) + for relationship in metadata.get('relationships', []): + parent_columns = \ + metadata['tables'][relationship['parent_table_name']]['columns'] + child_columns = \ + metadata['tables'][relationship['child_table_name']]['columns'] + iterations += (len(parent_columns) * len(child_columns)) return iterations def _compute_average(self): diff --git a/sdmetrics/reports/multi_table/_properties/inter_table_trends.py b/sdmetrics/reports/multi_table/_properties/inter_table_trends.py index 20d7c839..5addd8dd 100644 --- a/sdmetrics/reports/multi_table/_properties/inter_table_trends.py +++ b/sdmetrics/reports/multi_table/_properties/inter_table_trends.py @@ -106,40 +106,40 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No The progress bar object. Defaults to None. """ all_details = [] - if 'relationships' in metadata: - for relationship in metadata['relationships']: - parent = relationship['parent_table_name'] - child = relationship['child_table_name'] - foreign_key = relationship['child_foreign_key'] - - denormalized_real, denormalized_synthetic = self._denormalize_tables( - real_data, - synthetic_data, - relationship - ) - - merged_metadata, parent_cols, child_cols = self._merge_metadata( - metadata, - parent, - child - ) - - parent_child_pairs = itertools.product(parent_cols, child_cols) - - self._properties[(parent, child, foreign_key)] = SingleTableColumnPairTrends() - details = self._properties[(parent, child, foreign_key)]._generate_details( - denormalized_real, denormalized_synthetic, merged_metadata, - progress_bar=progress_bar, column_pairs=parent_child_pairs - ) - - details['Parent Table'] = parent - details['Child Table'] = child - details['Foreign Key'] = foreign_key - if not details.empty: - details['Column 1'] = details['Column 1'].str.replace(f'{parent}.', '', n=1) - details['Column 2'] = details['Column 2'].str.replace(f'{child}.', '', n=1) - all_details.append(details) + for relationship in metadata.get('relationships', []): + parent = relationship['parent_table_name'] + child = relationship['child_table_name'] + foreign_key = relationship['child_foreign_key'] + + denormalized_real, denormalized_synthetic = self._denormalize_tables( + real_data, + synthetic_data, + relationship + ) + + merged_metadata, parent_cols, child_cols = self._merge_metadata( + metadata, + parent, + child + ) + + parent_child_pairs = itertools.product(parent_cols, child_cols) + + self._properties[(parent, child, foreign_key)] = SingleTableColumnPairTrends() + details = self._properties[(parent, child, foreign_key)]._generate_details( + denormalized_real, denormalized_synthetic, merged_metadata, + progress_bar=progress_bar, column_pairs=parent_child_pairs + ) + + details['Parent Table'] = parent + details['Child Table'] = child + details['Foreign Key'] = foreign_key + if not details.empty: + details['Column 1'] = details['Column 1'].str.replace(f'{parent}.', '', n=1) + details['Column 2'] = details['Column 2'].str.replace(f'{child}.', '', n=1) + all_details.append(details) + if len(all_details) > 0: self.details = pd.concat(all_details, axis=0).reset_index(drop=True) detail_columns = [ 'Parent Table', 'Child Table', 'Foreign Key', 'Column 1', 'Column 2',