diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index d950e659..ebadbb3c 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -3,12 +3,30 @@ import pandas as pd import plotly.express as px import plotly.figure_factory as ff +import plotly.io as pio from pandas.api.types import is_datetime64_dtype from sdmetrics.reports.utils import PlotConfig from sdmetrics.utils import get_missing_percentage, is_datetime +def _set_plotly_config(): + """Set the ``plotly`` config according to the environment.""" + renderers = list(pio.renderers) + if getattr('get_ipython', __builtin__): + ipython_interpreter = get_ipython() + if 'colab' in ipython_interpreter and 'colab' in renderers: + pio.renderers.default = 'colab' + elif 'ZMQInteractiveShell' in ipython_interpreter and 'notebook' in renderers: + pio.renderers.default = 'notebook' + + elif 'iframe' in renderers: + pio.renderers.default = 'iframe' + + +_set_plotly_config() + + def _generate_column_bar_plot(real_data, synthetic_data, plot_kwargs={}): """Generate a bar plot of the real and synthetic data.