From cca85a2d83f1ad059fe9e51edd9a9826545119d7 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 8 Nov 2023 12:22:01 -0600 Subject: [PATCH] Update formatting --- sdmetrics/visualization.py | 30 +++++++++++++++--------------- tests/unit/test_visualization.py | 14 +++++++++++++- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index f25d1ce9..2ee3f4f7 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -507,6 +507,11 @@ def _generate_line_plot(real_data, synthetic_data, x_axis, y_axis, marker, annot pd.api.types.is_numeric_dtype(all_data[x_axis])): raise ValueError( f"Sequence Index '{x_axis}' must contain numerical or datetime values only") + if not (is_datetime(all_data[y_axis]) or + pd.api.types.is_numeric_dtype(all_data[y_axis])): + raise ValueError( + f"Column Name '{y_axis}' must contain numerical or datetime values only") + fig = px.line(all_data, x=x_axis, y=y_axis, color=marker, color_discrete_map={ 'Real': PlotConfig.DATACEBO_DARK, @@ -514,10 +519,9 @@ def _generate_line_plot(real_data, synthetic_data, x_axis, y_axis, marker, annot }) if annotations: fig.add_annotation(annotations) + if x_axis == 'sequence_index': - fig.update_xaxes( - title_text='Sequence Position' - ) + fig.update_xaxes(title_text='Sequence Position') fig.update_layout( title_text=f"Real vs Synthetic Data for column: '{y_axis}'", @@ -599,12 +603,6 @@ def get_column_line_plot(real_data, synthetic_data, column_name, metadata): real_column = real_data[column_name] synthetic_column = synthetic_data[column_name] - # Check if the column is the appropriate type - if not (is_datetime(real_column) or is_datetime(synthetic_column) - or pd.api.types.is_numeric_dtype(real_column) or - pd.api.types.is_numeric_dtype(synthetic_column)): - raise ValueError(f"Column '{column_name}' must contain numerical or datetime values only") - missing_data_real = get_missing_percentage(real_column) missing_data_synthetic = get_missing_percentage(synthetic_column) show_missing_values = missing_data_real > 0 or missing_data_synthetic > 0 @@ -655,10 +653,12 @@ def get_column_line_plot(real_data, synthetic_data, column_name, metadata): s_data[marker_name] = 'Synthetic' # Generate plot - fig = _generate_line_plot(real_data=r_data, - synthetic_data=s_data, - x_axis=x_axis, - y_axis=y_axis, - marker=marker_name, - annotations=annotations) + fig = _generate_line_plot( + real_data=r_data, + synthetic_data=s_data, + x_axis=x_axis, + y_axis=y_axis, + marker=marker_name, + annotations=annotations + ) return fig diff --git a/tests/unit/test_visualization.py b/tests/unit/test_visualization.py index 9afc63e7..c1ff81f8 100644 --- a/tests/unit/test_visualization.py +++ b/tests/unit/test_visualization.py @@ -512,7 +512,7 @@ def test__generate_line_plot(px_mock): assert mock_figure.for_each_annotation.called_once() assert fig == mock_figure - # Setup failing case + # Setup failing case sequence index bad_data = pd.DataFrame({ 'colX': [1, 'bad_value', 4, 5], 'colY': [6, 7, 9, 18], @@ -524,6 +524,18 @@ def test__generate_line_plot(px_mock): with pytest.raises(ValueError, match=match): _generate_line_plot(real_data, bad_data, x_axis='colX', y_axis='colY', marker='Data') + # Setup failing case for column + bad_column = pd.DataFrame({ + 'colX': [1, 2, 4, 5], + 'colY': [6, 'bad_value', 9, 18], + 'Data': ['Synthetic'] * 4 + }) + + # Run and Assert + match = "Column Name 'colY' must contain numerical or datetime values only" + with pytest.raises(ValueError, match=match): + _generate_line_plot(real_data, bad_column, x_axis='colX', y_axis='colY', marker='Data') + @patch('sdmetrics.visualization.px') def test__generate_box_plot(px_mock):