diff --git a/sdmetrics/reports/multi_table/_properties/structure.py b/sdmetrics/reports/multi_table/_properties/structure.py index 1195b592..affabd07 100644 --- a/sdmetrics/reports/multi_table/_properties/structure.py +++ b/sdmetrics/reports/multi_table/_properties/structure.py @@ -1,6 +1,10 @@ """Structure property for multi-table.""" +import plotly.express as px + +from sdmetrics.errors import VisualizationUnavailableError from sdmetrics.reports.multi_table._properties import BaseMultiTableProperty from sdmetrics.reports.single_table._properties import Structure as SingleTableStructure +from sdmetrics.reports.utils import PlotConfig class Structure(BaseMultiTableProperty): @@ -12,3 +16,52 @@ class Structure(BaseMultiTableProperty): _single_table_property = SingleTableStructure _num_iteration_case = 'table' + + def get_visualization(self, table_name=None): + """Return a visualization for each score in the property. + + Args: + table_name: + If a table name is provided, an error is raised. + + Returns: + plotly.graph_objects._figure.Figure + The visualization for the property. + """ + if table_name: + raise VisualizationUnavailableError( + 'The Structure property does not have a supported visualization for' + ' individual tables.' + ) + + average_score = self._compute_average() + fig = px.bar( + data_frame=self.details, + x='Table', + y='Score', + title=f'Data Diagnostic: Structure (Average Score={average_score})', + category_orders={'group': list(self.details['Table'])}, + color='Metric', + color_discrete_map={ + 'TableFormat': PlotConfig.DATACEBO_DARK, + }, + pattern_shape='Metric', + pattern_shape_sequence=[''], + hover_name='Table', + hover_data={ + 'Table': False, + 'Metric': True, + 'Score': True, + }, + ) + + fig.update_yaxes(range=[0, 1]) + + fig.update_layout( + xaxis_categoryorder='total ascending', + plot_bgcolor=PlotConfig.BACKGROUND_COLOR, + margin={'t': 150}, + font={'size': PlotConfig.FONT_SIZE}, + ) + + return fig diff --git a/sdmetrics/reports/multi_table/base_multi_table_report.py b/sdmetrics/reports/multi_table/base_multi_table_report.py index a61b55e7..c116a2da 100644 --- a/sdmetrics/reports/multi_table/base_multi_table_report.py +++ b/sdmetrics/reports/multi_table/base_multi_table_report.py @@ -100,6 +100,9 @@ def get_visualization(self, property_name, table_name=None): plotly.graph_objects._figure.Figure The visualization for the requested property. """ + if property_name == 'Data Structure': + return self._properties[property_name].get_visualization(table_name) + if table_name is None: raise ValueError( 'Please provide a table name to get a visualization for the property.' diff --git a/tests/unit/reports/multi_table/_properties/test_structure.py b/tests/unit/reports/multi_table/_properties/test_structure.py index 63c949df..ddaf9175 100644 --- a/tests/unit/reports/multi_table/_properties/test_structure.py +++ b/tests/unit/reports/multi_table/_properties/test_structure.py @@ -1,4 +1,10 @@ """Test Structure multi-table class.""" +from unittest.mock import Mock, patch + +import pandas as pd +import pytest + +from sdmetrics.errors import VisualizationUnavailableError from sdmetrics.reports.multi_table._properties import Structure from sdmetrics.reports.single_table._properties import Structure as SingleTableStructure @@ -12,3 +18,83 @@ def test__init__(): assert synthesis._properties == {} assert synthesis._single_table_property == SingleTableStructure assert synthesis._num_iteration_case == 'table' + + +@patch('sdmetrics.reports.multi_table._properties.structure.px') +def test_get_visualization(mock_px): + """Test the ``get_visualization`` method.""" + # Setup + structure_property = Structure() + + mock_df = pd.DataFrame({ + 'Table': ['Table1', 'Table2'], + 'Score': [0.7, 0.3], + 'Metric': ['TableFormat', 'TableFormat'] + }) + structure_property.details = mock_df + + mock__compute_average = Mock(return_value=0.5) + structure_property._compute_average = mock__compute_average + + mock_bar = Mock() + mock_px.bar.return_value = mock_bar + + # Run + structure_property.get_visualization() + + # Assert + mock__compute_average.assert_called_once() + + # Expected call + expected_kwargs = { + 'data_frame': mock_df, + 'x': 'Table', + 'y': 'Score', + 'title': ( + 'Data Diagnostic: Structure (Average ' + f'Score={mock__compute_average.return_value})' + ), + 'category_orders': {'group': mock_df['Table'].tolist()}, + 'color': 'Metric', + 'color_discrete_map': { + 'TableFormat': '#000036', + }, + 'pattern_shape': 'Metric', + 'pattern_shape_sequence': [''], + 'hover_name': 'Table', + 'hover_data': { + 'Table': False, + 'Metric': True, + 'Score': True, + }, + } + + # Check call_args of mock_px.bar + _, kwargs = mock_px.bar.call_args + + # Check DataFrame separately + assert kwargs.pop('data_frame').equals(expected_kwargs.pop('data_frame')) + + # Check other arguments + assert kwargs == expected_kwargs + + mock_bar.update_yaxes.assert_called_once_with(range=[0, 1]) + mock_bar.update_layout.assert_called_once_with( + xaxis_categoryorder='total ascending', + plot_bgcolor='#F5F5F8', + margin={'t': 150}, + font={'size': 18} + ) + + +def test_get_visualization_with_table_name(): + """Test the ``get_visualization`` when a table name is given.""" + # Setup + synthesis = Structure() + + # Run and Assert + expected_message = ( + 'The Structure property does not have a supported visualization for individual tables.' + ) + with pytest.raises(VisualizationUnavailableError, match=expected_message): + synthesis.get_visualization('table_name') diff --git a/tests/unit/reports/multi_table/test_base_multi_table_report.py b/tests/unit/reports/multi_table/test_base_multi_table_report.py index a5f7c4d4..cbabefa6 100644 --- a/tests/unit/reports/multi_table/test_base_multi_table_report.py +++ b/tests/unit/reports/multi_table/test_base_multi_table_report.py @@ -304,3 +304,18 @@ def test_get_visualization_without_table_name(self): with pytest.raises(ValueError, match=expected_error_message): report.get_visualization('Property_1') + + def test_get_visualization_for_structure_property(self): + """Test the ``get_visualization`` method for the structure property.""" + # Setup + report = BaseMultiTableReport() + report._properties = { + 'Data Structure': Mock() + } + report._properties['Data Structure'].get_visualization = Mock() + + # Run + report.get_visualization('Data Structure', 'Table_1') + + # Assert + report._properties['Data Structure'].get_visualization.assert_called_once_with('Table_1')