diff --git a/sdmetrics/reports/multi_table/_properties/__init__.py b/sdmetrics/reports/multi_table/_properties/__init__.py index d43faaf1..2483e59f 100644 --- a/sdmetrics/reports/multi_table/_properties/__init__.py +++ b/sdmetrics/reports/multi_table/_properties/__init__.py @@ -8,6 +8,7 @@ from sdmetrics.reports.multi_table._properties.coverage import Coverage from sdmetrics.reports.multi_table._properties.data_validity import DataValidity from sdmetrics.reports.multi_table._properties.inter_table_trends import InterTableTrends +from sdmetrics.reports.multi_table._properties.relationship_validity import RelationshipValidity from sdmetrics.reports.multi_table._properties.structure import Structure from sdmetrics.reports.multi_table._properties.synthesis import Synthesis @@ -21,5 +22,6 @@ 'InterTableTrends', 'Synthesis', 'Structure', - 'DataValidity' + 'DataValidity', + 'RelationshipValidity', ] diff --git a/sdmetrics/reports/multi_table/_properties/base.py b/sdmetrics/reports/multi_table/_properties/base.py index d885f0ea..1ca70f18 100644 --- a/sdmetrics/reports/multi_table/_properties/base.py +++ b/sdmetrics/reports/multi_table/_properties/base.py @@ -54,6 +54,14 @@ def _get_num_iterations(self, metadata): iterations += (len(parent_columns) * len(child_columns)) return iterations + @staticmethod + def _extract_tuple(data, relation): + parent_data = data[relation['parent_table_name']] + child_data = data[relation['child_table_name']] + return ( + parent_data[relation['parent_primary_key']], child_data[relation['child_foreign_key']] + ) + def _compute_average(self): """Average the scores for each column.""" is_dataframe = isinstance(self.details, pd.DataFrame) diff --git a/sdmetrics/reports/multi_table/_properties/cardinality.py b/sdmetrics/reports/multi_table/_properties/cardinality.py index e3e1726f..e6499da3 100644 --- a/sdmetrics/reports/multi_table/_properties/cardinality.py +++ b/sdmetrics/reports/multi_table/_properties/cardinality.py @@ -20,10 +20,12 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No """Get the average score of cardinality shape similarity in the given tables. Args: - real_data (pandas.DataFrame): - The real data. - synthetic_data (pandas.DataFrame): - The synthetic data. + real_data (dict[str, pandas.DataFrame]): + The tables from the real dataset, passed as a dictionary of + table names and pandas.DataFrames. + synthetic_data (dict[str, pandas.DataFrame]): + The tables from the synthetic dataset, passed as a dictionary of + table names and pandas.DataFrames. metadata (dict): The metadata, which contains each column's data type as well as relationships. progress_bar (tqdm.tqdm or None): diff --git a/sdmetrics/reports/multi_table/_properties/relationship_validity.py b/sdmetrics/reports/multi_table/_properties/relationship_validity.py new file mode 100644 index 00000000..80b9ae3a --- /dev/null +++ b/sdmetrics/reports/multi_table/_properties/relationship_validity.py @@ -0,0 +1,137 @@ +import numpy as np +import pandas as pd +import plotly.express as px + +from sdmetrics.column_pairs.statistical import CardinalityBoundaryAdherence, ReferentialIntegrity +from sdmetrics.reports.multi_table._properties.base import BaseMultiTableProperty +from sdmetrics.reports.utils import PlotConfig + + +class RelationshipValidity(BaseMultiTableProperty): + """``Relationship Validity`` property. + + This property measures the validity of the relationship + from the primary key and the foreign key perspective. 
+ + """ + + _num_iteration_case = 'relationship' + + def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=None): + """Generate the _details dataframe for the relationship validity property. + + Args: + real_data (dict[str, pandas.DataFrame]): + The tables from the real dataset, passed as a dictionary of + table names and pandas.DataFrames. + synthetic_data (dict[str, pandas.DataFrame]): + The tables from the synthetic dataset, passed as a dictionary of + table names and pandas.DataFrames. + metadata (dict): + The metadata, which contains each column's data type as well as relationships. + progress_bar (tqdm.tqdm or None): + The progress bar object. Defaults to ``None``. + + Returns: + float: + The average score for the property for all the individual metric scores computed. + """ + child_tables, parent_tables = [], [] + primary_key, foreign_key = [], [] + metric_names, scores, error_messages = [], [], [] + metrics = [ReferentialIntegrity, CardinalityBoundaryAdherence] + for relation in metadata.get('relationships', []): + real_columns = self._extract_tuple(real_data, relation) + synthetic_columns = self._extract_tuple(synthetic_data, relation) + for metric in metrics: + try: + relation_score = metric.compute( + real_columns, + synthetic_columns, + ) + error_message = None + except Exception as e: + relation_score = np.nan + error_message = f'{type(e).__name__}: {e}' + + child_tables.append(relation['child_table_name']) + parent_tables.append(relation['parent_table_name']) + primary_key.append(relation['parent_primary_key']) + foreign_key.append(relation['child_foreign_key']) + metric_names.append(metric.__name__) + scores.append(relation_score) + error_messages.append(error_message) + + if progress_bar: + progress_bar.update() + + self.details = pd.DataFrame({ + 'Parent Table': parent_tables, + 'Child Table': child_tables, + 'Primary Key': primary_key, + 'Foreign Key': foreign_key, + 'Metric': metric_names, + 'Score': scores, + 'Error': error_messages, + }) + + def _get_table_relationships_plot(self, table_name): + """Get the table relationships plot from the parent child relationship scores for a table. + + Args: + table_name (str): + Table name to get details table for. + + Returns: + plotly.graph_objects._figure.Figure + """ + plot_data = self.get_details(table_name).copy() + column_name = 'Child → Parent Relationship' + plot_data[column_name] = ( + plot_data['Child Table'] + ' (' + plot_data['Foreign Key'] + ') → ' + + plot_data['Parent Table'] + ) + plot_data = plot_data.drop(['Child Table', 'Parent Table'], axis=1) + + average_score = round(plot_data['Score'].mean(), 2) + + fig = px.bar( + plot_data, + x='Child → Parent Relationship', + y='Score', + title=f'Data Diagnostic: Relationship Validity (Average Score={average_score})', + color='Metric', + color_discrete_sequence=[PlotConfig.DATACEBO_DARK, PlotConfig.DATACEBO_GREEN], + pattern_shape='Metric', + pattern_shape_sequence=['', '/'], + hover_name='Child → Parent Relationship', + hover_data={ + 'Child → Parent Relationship': False, + 'Metric': True, + 'Score': True, + }, + barmode='group' + ) + + fig.update_yaxes(range=[0, 1]) + + fig.update_layout( + xaxis_categoryorder='total ascending', + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + font={'size': PlotConfig.FONT_SIZE} + ) + + return fig + + def get_visualization(self, table_name): + """Return a visualization for each score in the property. + + Args: + table_name (str): + Table name to get the visualization for. 
+ + Returns: + plotly.graph_objects._figure.Figure + The visualization for the property. + """ + return self._get_table_relationships_plot(table_name) diff --git a/tests/integration/reports/multi_table/_properties/test_relationship_validity.py b/tests/integration/reports/multi_table/_properties/test_relationship_validity.py new file mode 100644 index 00000000..6ccbbc77 --- /dev/null +++ b/tests/integration/reports/multi_table/_properties/test_relationship_validity.py @@ -0,0 +1,41 @@ +import sys + +from tqdm import tqdm + +from sdmetrics.demos import load_demo +from sdmetrics.reports.multi_table._properties import RelationshipValidity + + +class TestRelationshipValidity: + + def test_end_to_end(self): + """Test the ``RelationshipValidity`` multi-table property end to end.""" + # Setup + real_data, synthetic_data, metadata = load_demo(modality='multi_table') + relationship_validity = RelationshipValidity() + + # Run + result = relationship_validity.get_score(real_data, synthetic_data, metadata) + + # Assert + assert result == 1.0 + + def test_with_progress_bar(self, capsys): + """Test that the progress bar is correctly updated.""" + # Setup + real_data, synthetic_data, metadata = load_demo(modality='multi_table') + relationship_validity = RelationshipValidity() + num_relationship = 2 + + progress_bar = tqdm(total=num_relationship, file=sys.stdout) + + # Run + result = relationship_validity.get_score(real_data, synthetic_data, metadata, progress_bar) + progress_bar.close() + captured = capsys.readouterr() + output = captured.out + + # Assert + assert result == 1.0 + assert '100%' in output + assert f'{num_relationship}/{num_relationship}' in output diff --git a/tests/unit/reports/multi_table/_properties/test_base.py b/tests/unit/reports/multi_table/_properties/test_base.py index 1c75b02d..21f0c69e 100644 --- a/tests/unit/reports/multi_table/_properties/test_base.py +++ b/tests/unit/reports/multi_table/_properties/test_base.py @@ -90,6 +90,36 @@ def test__get_num_iterations(self): base_property._num_iteration_case = 'inter_table_column_pair' assert base_property._get_num_iterations(metadata) == 11 + def test__extract_tuple(self): + """Test the ``_extract_tuple`` method.""" + # Setup + base_property = BaseMultiTableProperty() + real_user_df = pd.DataFrame({ + 'user_id': ['user1', 'user2'], + 'columnA': ['A', 'B'], + 'columnB': [np.nan, 1.0] + }) + real_session_df = pd.DataFrame({ + 'session_id': ['session1', 'session2', 'session3'], + 'user_id': ['user1', 'user1', 'user2'], + 'columnC': ['X', 'Y', 'Z'], + 'columnD': [4.0, 6.0, 7.0] + }) + + real_data = {'users': real_user_df, 'sessions': real_session_df} + relation = { + 'parent_table_name': 'users', + 'child_table_name': 'sessions', + 'parent_primary_key': 'user_id', + 'child_foreign_key': 'user_id' + } + + # Run + real_columns = base_property._extract_tuple(real_data, relation) + + # Assert + assert real_columns == (real_data['users']['user_id'], real_data['sessions']['user_id']) + def test__generate_details_property(self): """Test the ``_generate_details`` method.""" # Setup diff --git a/tests/unit/reports/multi_table/_properties/test_relationship_validity.py b/tests/unit/reports/multi_table/_properties/test_relationship_validity.py new file mode 100644 index 00000000..2c8b1b37 --- /dev/null +++ b/tests/unit/reports/multi_table/_properties/test_relationship_validity.py @@ -0,0 +1,310 @@ +"""Test multi-table RelationshipValidity properties.""" + +from unittest.mock import Mock, patch + +import numpy as np +import pandas as pd +import pytest 
+from plotly.graph_objects import Figure + +from sdmetrics.reports.multi_table._properties.relationship_validity import RelationshipValidity + + +@pytest.fixture() +def real_data_fixture(): + real_user_df = pd.DataFrame({ + 'user_id': ['user1', 'user2'], + 'columnA': ['A', 'B'], + 'columnB': [np.nan, 1.0] + }) + real_session_df = pd.DataFrame({ + 'session_id': ['session1', 'session2', 'session3'], + 'user_id': ['user1', 'user1', 'user2'], + 'columnC': ['X', 'Y', 'Z'], + 'columnD': [4.0, 6.0, 7.0] + }) + return {'users': real_user_df, 'sessions': real_session_df} + + +@pytest.fixture() +def synthetic_data_fixture(): + synthetic_user_df = pd.DataFrame({ + 'user_id': ['user1', 'user2'], + 'columnA': ['A', 'A'], + 'columnB': [0.5, np.nan] + }) + synthetic_session_df = pd.DataFrame({ + 'session_id': ['session1', 'session2', 'session3'], + 'user_id': ['user1', 'user1', 'user2'], + 'columnC': ['X', 'Z', 'Y'], + 'columnD': [3.6, 5.0, 6.0] + }) + return {'users': synthetic_user_df, 'sessions': synthetic_session_df} + + +@pytest.fixture() +def metadata_fixture(): + return { + 'tables': { + 'users': { + 'primary_key': 'user_id', + 'columns': { + 'user_id': {'sdtype': 'id'}, + 'columnA': {'sdtype': 'categorical'}, + 'columnB': {'sdtype': 'numerical'} + }, + }, + 'sessions': { + 'primary_key': 'session_id', + 'columns': { + 'session_id': {'sdtype': 'id'}, + 'user_id': {'sdtype': 'id'}, + 'columnC': {'sdtype': 'categorical'}, + 'columnD': {'sdtype': 'numerical'} + } + } + }, + 'relationships': [ + { + 'parent_table_name': 'users', + 'child_table_name': 'sessions', + 'parent_primary_key': 'user_id', + 'child_foreign_key': 'user_id' + } + ] + } + + +class TestRelationshipValidity: + + def test__extract_tuple(self, real_data_fixture): + """Test the ``_extract_tuple`` method.""" + # Setup + relationship_validity = RelationshipValidity() + real_data = real_data_fixture + relation = { + 'parent_table_name': 'users', + 'child_table_name': 'sessions', + 'parent_primary_key': 'user_id', + 'child_foreign_key': 'user_id' + } + + # Run + real_columns = relationship_validity._extract_tuple(real_data, relation) + + # Assert + assert real_columns == (real_data['users']['user_id'], real_data['sessions']['user_id']) + + def test__get_num_iteration(self): + """Test the ``_get_num_iterations`` method.""" + # Setup + metadata = { + 'relationships': [ + { + 'parent_table_name': 'table1', + 'parent_primary_key': 'col1', + 'child_table_name': 'table2', + 'child_foreign_key': 'col6' + }, + { + 'parent_table_name': 'table1', + 'parent_primary_key': 'col1', + 'child_table_name': 'table3', + 'child_foreign_key': 'col7' + }, + { + 'parent_table_name': 'table2', + 'parent_primary_key': 'col6', + 'child_table_name': 'table4', + 'child_foreign_key': 'col8' + }, + ] + } + relationship_validity = RelationshipValidity() + + # Run + num_iterations = relationship_validity._get_num_iterations(metadata) + + # Assert + assert num_iterations == 3 + + @patch('sdmetrics.reports.multi_table._properties.relationship_validity.' + 'CardinalityBoundaryAdherence') + @patch('sdmetrics.reports.multi_table._properties.relationship_validity.ReferentialIntegrity') + def test_get_score( + self, mock_referentialintegrity, mock_cardinalityboundaryadherence, + real_data_fixture, synthetic_data_fixture, metadata_fixture + ): + """Test the ``get_score`` function. + + Test that when given a ``progress_bar`` and relationships, this calls + ``CardinalityBoundaryAdherence`` and ``ReferentialIntegrity`` compute + method for each relationship. 
+ """ + # Setup + mock_referentialintegrity.compute.return_value = 0.7 + mock_referentialintegrity.__name__ = 'ReferentialIntegrity' + mock_cardinalityboundaryadherence.compute.return_value = 0.3 + mock_cardinalityboundaryadherence.__name__ = 'CardinalityBoundaryAdherence' + mock_compute_average = Mock(return_value=0.5) + relationship_validity = RelationshipValidity() + relationship_validity._compute_average = mock_compute_average + progress_bar = Mock() + + real_data = real_data_fixture + synthetic_data = synthetic_data_fixture + metadata = metadata_fixture + + # Run + score = relationship_validity.get_score( + real_data=real_data, synthetic_data=synthetic_data, metadata=metadata, + progress_bar=progress_bar + ) + + # Assert + expected_details_property = pd.DataFrame({ + 'Parent Table': ['users', 'users'], + 'Child Table': ['sessions', 'sessions'], + 'Primary Key': ['user_id', 'user_id'], + 'Foreign Key': ['user_id', 'user_id'], + 'Metric': ['ReferentialIntegrity', 'CardinalityBoundaryAdherence'], + 'Score': [0.7, 0.3], + }) + + assert score == 0.5 + progress_bar.update.assert_called() + progress_bar.update.assert_called_once() + mock_compute_average.assert_called_once() + pd.testing.assert_frame_equal(relationship_validity.details, expected_details_property) + + @patch('sdmetrics.reports.multi_table._properties.relationship_validity.' + 'CardinalityBoundaryAdherence') + @patch('sdmetrics.reports.multi_table._properties.relationship_validity.ReferentialIntegrity') + def test_get_score_raises_errors( + self, mock_referentialintegrity, mock_cardinalityboundaryadherence, + real_data_fixture, synthetic_data_fixture, metadata_fixture + ): + """Test the ``get_score`` when ``ReferentialIntegrity`` or + ``CardinalityBoundaryAdherence`` crashes""" + # Setup + mock_referentialintegrity.compute.side_effect = [ValueError('error 1')] + mock_referentialintegrity.__name__ = 'ReferentialIntegrity' + mock_cardinalityboundaryadherence.compute.side_effect = [ValueError('error 2')] + mock_cardinalityboundaryadherence.__name__ = 'CardinalityBoundaryAdherence' + relationship_validity = RelationshipValidity() + progress_bar = Mock() + + real_data = real_data_fixture + synthetic_data = synthetic_data_fixture + metadata = metadata_fixture + + # Run + score = relationship_validity.get_score( + real_data=real_data, synthetic_data=synthetic_data, metadata=metadata, + progress_bar=progress_bar + ) + + # Assert + expected_details_property = pd.DataFrame({ + 'Parent Table': ['users', 'users'], + 'Child Table': ['sessions', 'sessions'], + 'Primary Key': ['user_id', 'user_id'], + 'Foreign Key': ['user_id', 'user_id'], + 'Metric': ['ReferentialIntegrity', 'CardinalityBoundaryAdherence'], + 'Score': [np.nan, np.nan], + 'Error': ['ValueError: error 1', 'ValueError: error 2'] + }) + + assert pd.isna(score) + pd.testing.assert_frame_equal(relationship_validity.details, expected_details_property) + progress_bar.update.assert_called() + progress_bar.update.assert_called_once() + + def test_get_details_with_table_name(self): + """Test the ``get_details`` method. + + Test that the method returns the correct details for the given table name, + either from the child or parent table. 
+ """ + # Setup + relationship_validity = RelationshipValidity() + relationship_validity.details = pd.DataFrame({ + 'Child Table': ['users_child', 'sessions_child'], + 'Parent Table': ['users_parent', 'sessions_parent'], + 'Primary Key': ['user_id', 'user_id'], + 'Foreign Key': ['user_id', 'user_id'], + 'Metric': ['ReferentialIntegrity', 'CardinalityBoundaryAdherence'], + 'Score': [1.0, 0.5], + 'Error': [None, 'Some error'] + }) + + # Run + details_users_child = relationship_validity.get_details('users_child') + details_sessions_parent = relationship_validity.get_details('sessions_parent') + + # Assert for child table + assert details_users_child.equals(pd.DataFrame({ + 'Child Table': ['users_child'], + 'Parent Table': ['users_parent'], + 'Primary Key': ['user_id'], + 'Foreign Key': ['user_id'], + 'Metric': ['ReferentialIntegrity'], + 'Score': [1.0], + 'Error': [None] + }, index=[0])) + + # Assert for parent table + assert details_sessions_parent.equals(pd.DataFrame({ + 'Child Table': ['sessions_child'], + 'Parent Table': ['sessions_parent'], + 'Primary Key': ['user_id'], + 'Foreign Key': ['user_id'], + 'Metric': ['CardinalityBoundaryAdherence'], + 'Score': [0.5], + 'Error': ['Some error'] + }, index=[1])) + + def test_get_table_relationships_plot(self): + """Test the ``_get_table_relationships_plot`` method. + + Test that the method returns the correct plotly figure for the given table name. + """ + # Setup + instance = RelationshipValidity() + instance.details = pd.DataFrame({ + 'Child Table': ['users_child', 'sessions_child'], + 'Parent Table': ['users_parent', 'sessions_parent'], + 'Primary Key': ['user_id', 'user_id'], + 'Foreign Key': ['user_id', 'user_id'], + 'Metric': ['ReferentialIntegrity', 'CardinalityBoundaryAdherence'], + 'Score': [1.0, 0.5], + 'Error': [None, 'Some error'] + }) + + # Run + fig = instance._get_table_relationships_plot('users_child') + + # Assert + assert isinstance(fig, Figure) + + expected_x = ['users_child (user_id) → users_parent'] + expected_y = [1.0] + expected_title = 'Data Diagnostic: Relationship Validity (Average Score=1.0)' + + assert fig.data[0].x.tolist() == expected_x + assert fig.data[0].y.tolist() == expected_y + assert fig.layout.title.text == expected_title + + def test_get_visualization(self): + """Test the ``get_visualization`` method.""" + # Setup + mock__get_table_relationships_plot = Mock(side_effect=[Figure()]) + relationship_validity = RelationshipValidity() + relationship_validity._get_table_relationships_plot = mock__get_table_relationships_plot + + # Run + fig = relationship_validity.get_visualization('table_name') + + # Assert + assert isinstance(fig, Figure) + mock__get_table_relationships_plot.assert_called_once_with('table_name')