Skip to content

Commit

Permalink
Update formatting
Browse files: browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Nov 8, 2023
1 parent 4341d76 commit cca85a2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
30 changes: 15 additions & 15 deletions sdmetrics/visualization.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -507,17 +507,21 @@ def _generate_line_plot(real_data, synthetic_data, x_axis, y_axis, marker, annot
pd.api.types.is_numeric_dtype(all_data[x_axis])):
raise ValueError(
f"Sequence Index '{x_axis}' must contain numerical or datetime values only")
if not (is_datetime(all_data[y_axis]) or
pd.api.types.is_numeric_dtype(all_data[y_axis])):
raise ValueError(
f"Column Name '{y_axis}' must contain numerical or datetime values only")

fig = px.line(all_data, x=x_axis, y=y_axis, color=marker,
color_discrete_map={
'Real': PlotConfig.DATACEBO_DARK,
'Synthetic': PlotConfig.DATACEBO_GREEN
})
if annotations:
fig.add_annotation(annotations)

if x_axis == 'sequence_index':
fig.update_xaxes(
title_text='Sequence Position'
)
fig.update_xaxes(title_text='Sequence Position')

fig.update_layout(
title_text=f"Real vs Synthetic Data for column: '{y_axis}'",
Expand Down Expand Up @@ -599,12 +603,6 @@ def get_column_line_plot(real_data, synthetic_data, column_name, metadata):
real_column = real_data[column_name]
synthetic_column = synthetic_data[column_name]

# Check if the column is the appropriate type
if not (is_datetime(real_column) or is_datetime(synthetic_column)
or pd.api.types.is_numeric_dtype(real_column) or
pd.api.types.is_numeric_dtype(synthetic_column)):
raise ValueError(f"Column '{column_name}' must contain numerical or datetime values only")

missing_data_real = get_missing_percentage(real_column)
missing_data_synthetic = get_missing_percentage(synthetic_column)
show_missing_values = missing_data_real > 0 or missing_data_synthetic > 0
Expand Down Expand Up @@ -655,10 +653,12 @@ def get_column_line_plot(real_data, synthetic_data, column_name, metadata):
s_data[marker_name] = 'Synthetic'

# Generate plot
fig = _generate_line_plot(real_data=r_data,
synthetic_data=s_data,
x_axis=x_axis,
y_axis=y_axis,
marker=marker_name,
annotations=annotations)
fig = _generate_line_plot(
real_data=r_data,
synthetic_data=s_data,
x_axis=x_axis,
y_axis=y_axis,
marker=marker_name,
annotations=annotations
)
return fig
14 changes: 13 additions & 1 deletion tests/unit/test_visualization.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -512,7 +512,7 @@ def test__generate_line_plot(px_mock):
assert mock_figure.for_each_annotation.called_once()
assert fig == mock_figure

# Setup failing case
# Setup failing case sequence index
bad_data = pd.DataFrame({
'colX': [1, 'bad_value', 4, 5],
'colY': [6, 7, 9, 18],
Expand All @@ -524,6 +524,18 @@ def test__generate_line_plot(px_mock):
with pytest.raises(ValueError, match=match):
_generate_line_plot(real_data, bad_data, x_axis='colX', y_axis='colY', marker='Data')

# Setup failing case for column
bad_column = pd.DataFrame({
'colX': [1, 2, 4, 5],
'colY': [6, 'bad_value', 9, 18],
'Data': ['Synthetic'] * 4
})

# Run and Assert
match = "Column Name 'colY' must contain numerical or datetime values only"
with pytest.raises(ValueError, match=match):
_generate_line_plot(real_data, bad_column, x_axis='colX', y_axis='colY', marker='Data')


@patch('sdmetrics.visualization.px')
def test__generate_box_plot(px_mock):
Expand Down

0 comments on commit cca85a2

Please sign in to comment.